From f1c257845f58f2ffc7775b4b83a04cd6a67a8665 Mon Sep 17 00:00:00 2001
From: Stuart Nelson
Date: Thu, 8 Jan 2015 20:58:47 +0100
Subject: [PATCH] Vendor external dependencies with godep.

---
 Godeps/Godeps.json | 31 +
 Godeps/Readme | 5 +
 Godeps/_workspace/.gitignore | 2 +
 .../p/goprotobuf/proto/Makefile | 40 +
 .../p/goprotobuf/proto/all_test.go | 1979 +++++++++++++
 .../p/goprotobuf/proto/clone.go | 169 ++
 .../p/goprotobuf/proto/clone_test.go | 186 ++
 .../p/goprotobuf/proto/decode.go | 721 +++++
 .../p/goprotobuf/proto/encode.go | 1054 +++++++
 .../p/goprotobuf/proto/equal.go | 241 ++
 .../p/goprotobuf/proto/equal_test.go | 166 ++
 .../p/goprotobuf/proto/extensions.go | 351 +++
 .../p/goprotobuf/proto/extensions_test.go | 60 +
 .../code.google.com/p/goprotobuf/proto/lib.go | 740 +++++
 .../p/goprotobuf/proto/message_set.go | 229 ++
 .../p/goprotobuf/proto/message_set_test.go | 66 +
 .../p/goprotobuf/proto/pointer_reflect.go | 384 +++
 .../p/goprotobuf/proto/pointer_unsafe.go | 218 ++
 .../p/goprotobuf/proto/properties.go | 658 +++++
 .../p/goprotobuf/proto/size2_test.go | 63 +
 .../p/goprotobuf/proto/size_test.go | 120 +
 .../p/goprotobuf/proto/testdata/Makefile | 50 +
 .../goprotobuf/proto/testdata/golden_test.go | 86 +
 .../p/goprotobuf/proto/testdata/test.pb.go | 2350 +++++++++++++++
 .../p/goprotobuf/proto/testdata/test.proto | 428 +++
 .../p/goprotobuf/proto/text.go | 701 +++++
 .../p/goprotobuf/proto/text_parser.go | 684 +++++
 .../p/goprotobuf/proto/text_parser_test.go | 462 +++
 .../p/goprotobuf/proto/text_test.go | 408 +++
 .../src/github.com/golang/glog/LICENSE | 191 ++
 .../src/github.com/golang/glog/README | 44 +
 .../src/github.com/golang/glog/glog.go | 1177 ++++++++
 .../src/github.com/golang/glog/glog_file.go | 124 +
 .../src/github.com/golang/glog/glog_test.go | 415 +++
 .../ext/all_test.go | 355 +++
 .../golang_protobuf_extensions/ext/decode.go | 75 +
 .../golang_protobuf_extensions/ext/doc.go | 16 +
 .../golang_protobuf_extensions/ext/encode.go | 46 +
 .../ext/fixtures_test.go | 103 +
 .../src/github.com/miekg/dns/.gitignore | 4 +
 .../src/github.com/miekg/dns/.travis.yml | 19 +
 .../src/github.com/miekg/dns/AUTHORS | 1 +
 .../src/github.com/miekg/dns/CONTRIBUTORS | 9 +
 .../src/github.com/miekg/dns/COPYRIGHT | 9 +
 .../src/github.com/miekg/dns/LICENSE | 32 +
 .../src/github.com/miekg/dns/README.md | 140 +
 .../src/github.com/miekg/dns/client.go | 319 ++
 .../src/github.com/miekg/dns/client_test.go | 195 ++
 .../src/github.com/miekg/dns/clientconfig.go | 94 +
 .../src/github.com/miekg/dns/defaults.go | 242 ++
 .../src/github.com/miekg/dns/dns.go | 193 ++
 .../src/github.com/miekg/dns/dns_test.go | 501 ++++
 .../src/github.com/miekg/dns/dnssec.go | 739 +++++
 .../src/github.com/miekg/dns/dnssec_test.go | 646 +++++
 .../src/github.com/miekg/dns/dyn_test.go | 3 +
 .../src/github.com/miekg/dns/edns.go | 501 ++++
 .../src/github.com/miekg/dns/edns_test.go | 48 +
 .../src/github.com/miekg/dns/example_test.go | 147 +
 .../github.com/miekg/dns/idn/example_test.go | 18 +
 .../src/github.com/miekg/dns/idn/punycode.go | 268 ++
 .../github.com/miekg/dns/idn/punycode_test.go | 94 +
 .../src/github.com/miekg/dns/keygen.go | 149 +
 .../src/github.com/miekg/dns/kscan.go | 244 ++
 .../src/github.com/miekg/dns/labels.go | 162 ++
 .../src/github.com/miekg/dns/labels_test.go | 214 ++
 .../src/github.com/miekg/dns/msg.go | 1881 ++++++++++++
 .../src/github.com/miekg/dns/nsecx.go | 110 +
 .../src/github.com/miekg/dns/nsecx_test.go | 33 +
 .../src/github.com/miekg/dns/parse_test.go | 1287 ++++++++
 .../src/github.com/miekg/dns/privaterr.go | 122 +
 .../github.com/miekg/dns/privaterr_test.go | 169 ++
 .../src/github.com/miekg/dns/rawmsg.go | 95 +
 .../src/github.com/miekg/dns/scanner.go | 43 +
 .../src/github.com/miekg/dns/server.go | 612 ++++
 .../src/github.com/miekg/dns/server_test.go | 331 +++
 .../github.com/miekg/dns/singleinflight.go | 57 +
 .../src/github.com/miekg/dns/tlsa.go | 84 +
 .../src/github.com/miekg/dns/tsig.go | 371 +++
 .../src/github.com/miekg/dns/types.go | 1685 +++++++++++
 .../src/github.com/miekg/dns/udp.go | 55 +
 .../src/github.com/miekg/dns/udp_linux.go | 63 +
 .../src/github.com/miekg/dns/udp_other.go | 17 +
 .../src/github.com/miekg/dns/udp_windows.go | 34 +
 .../src/github.com/miekg/dns/update.go | 132 +
 .../src/github.com/miekg/dns/xfr.go | 236 ++
 .../src/github.com/miekg/dns/zgenerate.go | 157 +
 .../src/github.com/miekg/dns/zscan.go | 956 ++++++
 .../src/github.com/miekg/dns/zscan_rr.go | 2178 ++++++++++++++
 .../syndtr/goleveldb/leveldb/batch.go | 252 ++
 .../syndtr/goleveldb/leveldb/batch_test.go | 120 +
 .../syndtr/goleveldb/leveldb/bench_test.go | 464 +++
 .../syndtr/goleveldb/leveldb/cache/cache.go | 676 +++++
 .../goleveldb/leveldb/cache/cache_test.go | 570 ++++
 .../syndtr/goleveldb/leveldb/cache/lru.go | 195 ++
 .../syndtr/goleveldb/leveldb/comparer.go | 75 +
 .../leveldb/comparer/bytes_comparer.go | 51 +
 .../goleveldb/leveldb/comparer/comparer.go | 57 +
 .../syndtr/goleveldb/leveldb/corrupt_test.go | 500 ++++
 .../github.com/syndtr/goleveldb/leveldb/db.go | 943 ++++++
 .../syndtr/goleveldb/leveldb/db_compaction.go | 835 ++++++
 .../syndtr/goleveldb/leveldb/db_iter.go | 332 +++
 .../syndtr/goleveldb/leveldb/db_snapshot.go | 183 ++
 .../syndtr/goleveldb/leveldb/db_state.go | 202 ++
 .../syndtr/goleveldb/leveldb/db_test.go | 2579 +++++++++++++++++
 .../syndtr/goleveldb/leveldb/db_util.go | 100 +
 .../syndtr/goleveldb/leveldb/db_write.go | 311 ++
 .../syndtr/goleveldb/leveldb/doc.go | 90 +
 .../syndtr/goleveldb/leveldb/errors.go | 18 +
 .../syndtr/goleveldb/leveldb/errors/errors.go | 76 +
 .../syndtr/goleveldb/leveldb/external_test.go | 58 +
 .../syndtr/goleveldb/leveldb/filter.go | 31 +
 .../syndtr/goleveldb/leveldb/filter/bloom.go | 116 +
 .../goleveldb/leveldb/filter/bloom_test.go | 142 +
 .../syndtr/goleveldb/leveldb/filter/filter.go | 60 +
 .../goleveldb/leveldb/go13_bench_test.go | 58 +
 .../goleveldb/leveldb/iterator/array_iter.go | 184 ++
 .../leveldb/iterator/array_iter_test.go | 30 +
 .../leveldb/iterator/indexed_iter.go | 242 ++
 .../leveldb/iterator/indexed_iter_test.go | 83 +
 .../syndtr/goleveldb/leveldb/iterator/iter.go | 131 +
 .../leveldb/iterator/iter_suite_test.go | 11 +
 .../goleveldb/leveldb/iterator/merged_iter.go | 304 ++
 .../leveldb/iterator/merged_iter_test.go | 60 +
 .../goleveldb/leveldb/journal/journal.go | 520 ++++
 .../goleveldb/leveldb/journal/journal_test.go | 818 ++++++
 .../syndtr/goleveldb/leveldb/key.go | 142 +
 .../syndtr/goleveldb/leveldb/key_test.go | 133 +
 .../goleveldb/leveldb/leveldb_suite_test.go | 11 +
 .../goleveldb/leveldb/memdb/bench_test.go | 75 +
 .../syndtr/goleveldb/leveldb/memdb/memdb.go | 468 +++
 .../leveldb/memdb/memdb_suite_test.go | 11 +
 .../goleveldb/leveldb/memdb/memdb_test.go | 135 +
 .../syndtr/goleveldb/leveldb/opt/options.go | 624 ++++
 .../syndtr/goleveldb/leveldb/options.go | 92 +
 .../syndtr/goleveldb/leveldb/session.go | 455 +++
 .../goleveldb/leveldb/session_record.go | 313 ++
 .../goleveldb/leveldb/session_record_test.go | 64 +
 .../syndtr/goleveldb/leveldb/session_util.go | 249 ++
 .../goleveldb/leveldb/storage/file_storage.go | 534 ++++
 .../leveldb/storage/file_storage_plan9.go | 52 +
 .../leveldb/storage/file_storage_solaris.go | 68 +
 .../leveldb/storage/file_storage_test.go | 142 +
 .../leveldb/storage/file_storage_unix.go | 63 +
 .../leveldb/storage/file_storage_windows.go | 69 +
 .../goleveldb/leveldb/storage/mem_storage.go | 203 ++
 .../leveldb/storage/mem_storage_test.go | 66 +
 .../goleveldb/leveldb/storage/storage.go | 157 +
 .../syndtr/goleveldb/leveldb/storage_test.go | 539 ++++
 .../syndtr/goleveldb/leveldb/table.go | 521 ++++
 .../goleveldb/leveldb/table/block_test.go | 139 +
 .../syndtr/goleveldb/leveldb/table/reader.go | 1107 +++++++
 .../syndtr/goleveldb/leveldb/table/table.go | 177 ++
 .../leveldb/table/table_suite_test.go | 11 +
 .../goleveldb/leveldb/table/table_test.go | 122 +
 .../syndtr/goleveldb/leveldb/table/writer.go | 379 +++
 .../syndtr/goleveldb/leveldb/testutil/db.go | 222 ++
 .../goleveldb/leveldb/testutil/ginkgo.go | 21 +
 .../syndtr/goleveldb/leveldb/testutil/iter.go | 327 +++
 .../syndtr/goleveldb/leveldb/testutil/kv.go | 352 +++
 .../goleveldb/leveldb/testutil/kvtest.go | 187 ++
 .../goleveldb/leveldb/testutil/storage.go | 586 ++++
 .../syndtr/goleveldb/leveldb/testutil/util.go | 171 ++
 .../syndtr/goleveldb/leveldb/testutil_test.go | 63 +
 .../syndtr/goleveldb/leveldb/util.go | 91 +
 .../syndtr/goleveldb/leveldb/util/buffer.go | 293 ++
 .../goleveldb/leveldb/util/buffer_pool.go | 238 ++
 .../goleveldb/leveldb/util/buffer_test.go | 369 +++
 .../syndtr/goleveldb/leveldb/util/crc32.go | 30 +
 .../syndtr/goleveldb/leveldb/util/hash.go | 48 +
 .../syndtr/goleveldb/leveldb/util/pool.go | 21 +
 .../goleveldb/leveldb/util/pool_legacy.go | 33 +
 .../syndtr/goleveldb/leveldb/util/range.go | 32 +
 .../syndtr/goleveldb/leveldb/util/util.go | 73 +
 .../syndtr/goleveldb/leveldb/version.go | 436 +++
 .../syndtr/gosnappy/snappy/decode.go | 124 +
 .../syndtr/gosnappy/snappy/encode.go | 174 ++
 .../syndtr/gosnappy/snappy/snappy.go | 38 +
 .../syndtr/gosnappy/snappy/snappy_test.go | 261 ++
 Makefile | 2 +
 Makefile.INCLUDE | 1 +
 180 files changed, 53688 insertions(+)
 create mode 100644 Godeps/Godeps.json
 create mode 100644 Godeps/Readme
 create mode 100644 Godeps/_workspace/.gitignore
 create mode 100644 Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/Makefile
 create mode 100644 Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/all_test.go
 create mode 100644 Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/clone.go
 create mode 100644 Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/clone_test.go
 create mode 100644 Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/decode.go
 create mode 100644 Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/encode.go
 create mode 100644 Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/equal.go
 create mode 100644 Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/equal_test.go
 create mode 100644 Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/extensions.go
 create mode 100644 Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/extensions_test.go
 create mode 100644 Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/lib.go
 create mode 100644 Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/message_set.go
 create mode 100644 Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/message_set_test.go
 create mode 100644 Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/pointer_reflect.go
 create mode 100644 Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/pointer_unsafe.go
 create mode 100644 Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/properties.go
 create mode 100644 Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/size2_test.go
 create mode 100644 Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/size_test.go
 create mode 100644 Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/testdata/Makefile
 create mode 100644 Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/testdata/golden_test.go
 create mode 100644 Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/testdata/test.pb.go
 create mode 100644 Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/testdata/test.proto
 create mode 100644 Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/text.go
 create mode 100644 Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/text_parser.go
 create mode 100644 Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/text_parser_test.go
 create mode 100644 Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/text_test.go
 create mode 100644 Godeps/_workspace/src/github.com/golang/glog/LICENSE
 create mode 100644 Godeps/_workspace/src/github.com/golang/glog/README
 create mode 100644 Godeps/_workspace/src/github.com/golang/glog/glog.go
 create mode 100644 Godeps/_workspace/src/github.com/golang/glog/glog_file.go
 create mode 100644 Godeps/_workspace/src/github.com/golang/glog/glog_test.go
 create mode 100644 Godeps/_workspace/src/github.com/matttproud/golang_protobuf_extensions/ext/all_test.go
 create mode 100644 Godeps/_workspace/src/github.com/matttproud/golang_protobuf_extensions/ext/decode.go
 create mode 100644 Godeps/_workspace/src/github.com/matttproud/golang_protobuf_extensions/ext/doc.go
 create mode 100644 Godeps/_workspace/src/github.com/matttproud/golang_protobuf_extensions/ext/encode.go
 create mode 100644 Godeps/_workspace/src/github.com/matttproud/golang_protobuf_extensions/ext/fixtures_test.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/.gitignore
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/.travis.yml
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/AUTHORS
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/CONTRIBUTORS
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/COPYRIGHT
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/LICENSE
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/README.md
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/client.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/client_test.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/clientconfig.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/defaults.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/dns.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/dns_test.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/dnssec.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/dnssec_test.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/dyn_test.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/edns.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/edns_test.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/example_test.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/idn/example_test.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/idn/punycode.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/idn/punycode_test.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/keygen.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/kscan.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/labels.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/labels_test.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/msg.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/nsecx.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/nsecx_test.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/parse_test.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/privaterr.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/privaterr_test.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/rawmsg.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/scanner.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/server.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/server_test.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/singleinflight.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/tlsa.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/tsig.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/types.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/udp.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/udp_linux.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/udp_other.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/udp_windows.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/update.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/xfr.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/zgenerate.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/zscan.go
 create mode 100644 Godeps/_workspace/src/github.com/miekg/dns/zscan_rr.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/batch.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/batch_test.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/bench_test.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/cache/cache.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/cache/cache_test.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/cache/lru.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/comparer.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/comparer/bytes_comparer.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/comparer/comparer.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/corrupt_test.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db_compaction.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db_iter.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db_snapshot.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db_state.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db_test.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db_util.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db_write.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/doc.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/errors.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/errors/errors.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/external_test.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/filter.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/filter/bloom.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/filter/bloom_test.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/filter/filter.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/go13_bench_test.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/array_iter.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/array_iter_test.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/indexed_iter.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/indexed_iter_test.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/iter.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/iter_suite_test.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/merged_iter.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/merged_iter_test.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/journal/journal.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/journal/journal_test.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/key.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/key_test.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/leveldb_suite_test.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/memdb/bench_test.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/memdb/memdb.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/memdb/memdb_suite_test.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/memdb/memdb_test.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/opt/options.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/options.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/session.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/session_record.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/session_record_test.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/session_util.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/file_storage.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/file_storage_plan9.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/file_storage_solaris.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/file_storage_test.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/file_storage_unix.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/file_storage_windows.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/mem_storage.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/mem_storage_test.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/storage.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage_test.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/table.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/table/block_test.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/table/reader.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/table/table.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/table/table_suite_test.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/table/table_test.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/table/writer.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil/db.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil/ginkgo.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil/iter.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil/kv.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil/kvtest.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil/storage.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil/util.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil_test.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/buffer.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/buffer_pool.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/buffer_test.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/crc32.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/hash.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/pool.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/pool_legacy.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/range.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/util.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/version.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/gosnappy/snappy/decode.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/gosnappy/snappy/encode.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/gosnappy/snappy/snappy.go
 create mode 100644 Godeps/_workspace/src/github.com/syndtr/gosnappy/snappy/snappy_test.go

diff --git a/Godeps/Godeps.json b/Godeps/Godeps.json
new file mode 100644
index 0000000000..774f7caf8a
--- /dev/null
+++ b/Godeps/Godeps.json
@@ -0,0 +1,31 @@
+{
+	"ImportPath": "github.com/prometheus/prometheus",
+	"GoVersion": "go1.4",
+	"Deps": [
+		{
+			"ImportPath": "code.google.com/p/goprotobuf/proto",
+			"Comment": "go.r60-152",
+			"Rev": "36be16571e14f67e114bb0af619e5de2c1591679"
+		},
+		{
+			"ImportPath": "github.com/golang/glog",
+			"Rev": "44145f04b68cf362d9c4df2182967c2275eaefed"
+		},
+		{
+			"ImportPath": "github.com/matttproud/golang_protobuf_extensions/ext",
+			"Rev": "7a864a042e844af638df17ebbabf8183dace556a"
+		},
+		{
+			"ImportPath": "github.com/miekg/dns",
+			"Rev": "6b75215519f9916839204d80413bb178b94ef769"
+		},
+		{
+			"ImportPath": "github.com/syndtr/goleveldb/leveldb",
+			"Rev": "63c9e642efad852f49e20a6f90194cae112fd2ac"
+		},
+		{
+			"ImportPath": "github.com/syndtr/gosnappy/snappy",
+			"Rev": "ce8acff4829e0c2458a67ead32390ac0a381c862"
+		}
+	]
+}
diff --git a/Godeps/Readme b/Godeps/Readme
new file mode 100644
index 0000000000..4cdaa53d56
--- /dev/null
+++ b/Godeps/Readme
@@ -0,0 +1,5 @@
+This directory tree is generated automatically by godep.
+
+Please do not edit.
+
+See https://github.com/tools/godep for more information.
diff --git a/Godeps/_workspace/.gitignore b/Godeps/_workspace/.gitignore
new file mode 100644
index 0000000000..f037d684ef
--- /dev/null
+++ b/Godeps/_workspace/.gitignore
@@ -0,0 +1,2 @@
+/pkg
+/bin
diff --git a/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/Makefile b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/Makefile
new file mode 100644
index 0000000000..e99b839a7d
--- /dev/null
+++ b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/Makefile
@@ -0,0 +1,40 @@
+# Go support for Protocol Buffers - Google's data interchange format
+#
+# Copyright 2010 The Go Authors. All rights reserved.
+# http://code.google.com/p/goprotobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+install:
+	go install
+
+test: install generate-test-pbs
+	go test
+
+
+generate-test-pbs:
+	make install && cd testdata && make
diff --git a/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/all_test.go b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/all_test.go
new file mode 100644
index 0000000000..b2a0191fcd
--- /dev/null
+++ b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/all_test.go
@@ -0,0 +1,1979 @@
+// Go support for Protocol Buffers - Google's data interchange format
+//
+// Copyright 2010 The Go Authors. All rights reserved.
+// http://code.google.com/p/goprotobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+package proto_test
+
+import (
+	"bytes"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"math"
+	"math/rand"
+	"reflect"
+	"runtime/debug"
+	"strings"
+	"testing"
+	"time"
+
+	. "./testdata"
+	. "code.google.com/p/goprotobuf/proto"
+)
+
+var globalO *Buffer
+
+func old() *Buffer {
+	if globalO == nil {
+		globalO = NewBuffer(nil)
+	}
+	globalO.Reset()
+	return globalO
+}
+
+func equalbytes(b1, b2 []byte, t *testing.T) {
+	if len(b1) != len(b2) {
+		t.Errorf("wrong lengths: 2*%d != %d", len(b1), len(b2))
+		return
+	}
+	for i := 0; i < len(b1); i++ {
+		if b1[i] != b2[i] {
+			t.Errorf("bad byte[%d]:%x %x: %s %s", i, b1[i], b2[i], b1, b2)
+		}
+	}
+}
+
+func initGoTestField() *GoTestField {
+	f := new(GoTestField)
+	f.Label = String("label")
+	f.Type = String("type")
+	return f
+}
+
+// These are all structurally equivalent but the tag numbers differ.
+// (It's remarkable that required, optional, and repeated all have
+// 8 letters.)
+func initGoTest_RequiredGroup() *GoTest_RequiredGroup { + return &GoTest_RequiredGroup{ + RequiredField: String("required"), + } +} + +func initGoTest_OptionalGroup() *GoTest_OptionalGroup { + return &GoTest_OptionalGroup{ + RequiredField: String("optional"), + } +} + +func initGoTest_RepeatedGroup() *GoTest_RepeatedGroup { + return &GoTest_RepeatedGroup{ + RequiredField: String("repeated"), + } +} + +func initGoTest(setdefaults bool) *GoTest { + pb := new(GoTest) + if setdefaults { + pb.F_BoolDefaulted = Bool(Default_GoTest_F_BoolDefaulted) + pb.F_Int32Defaulted = Int32(Default_GoTest_F_Int32Defaulted) + pb.F_Int64Defaulted = Int64(Default_GoTest_F_Int64Defaulted) + pb.F_Fixed32Defaulted = Uint32(Default_GoTest_F_Fixed32Defaulted) + pb.F_Fixed64Defaulted = Uint64(Default_GoTest_F_Fixed64Defaulted) + pb.F_Uint32Defaulted = Uint32(Default_GoTest_F_Uint32Defaulted) + pb.F_Uint64Defaulted = Uint64(Default_GoTest_F_Uint64Defaulted) + pb.F_FloatDefaulted = Float32(Default_GoTest_F_FloatDefaulted) + pb.F_DoubleDefaulted = Float64(Default_GoTest_F_DoubleDefaulted) + pb.F_StringDefaulted = String(Default_GoTest_F_StringDefaulted) + pb.F_BytesDefaulted = Default_GoTest_F_BytesDefaulted + pb.F_Sint32Defaulted = Int32(Default_GoTest_F_Sint32Defaulted) + pb.F_Sint64Defaulted = Int64(Default_GoTest_F_Sint64Defaulted) + } + + pb.Kind = GoTest_TIME.Enum() + pb.RequiredField = initGoTestField() + pb.F_BoolRequired = Bool(true) + pb.F_Int32Required = Int32(3) + pb.F_Int64Required = Int64(6) + pb.F_Fixed32Required = Uint32(32) + pb.F_Fixed64Required = Uint64(64) + pb.F_Uint32Required = Uint32(3232) + pb.F_Uint64Required = Uint64(6464) + pb.F_FloatRequired = Float32(3232) + pb.F_DoubleRequired = Float64(6464) + pb.F_StringRequired = String("string") + pb.F_BytesRequired = []byte("bytes") + pb.F_Sint32Required = Int32(-32) + pb.F_Sint64Required = Int64(-64) + pb.Requiredgroup = initGoTest_RequiredGroup() + + return pb +} + +func fail(msg string, b *bytes.Buffer, s string, t *testing.T) { + data := b.Bytes() + ld := len(data) + ls := len(s) / 2 + + fmt.Printf("fail %s ld=%d ls=%d\n", msg, ld, ls) + + // find the interesting spot - n + n := ls + if ld < ls { + n = ld + } + j := 0 + for i := 0; i < n; i++ { + bs := hex(s[j])*16 + hex(s[j+1]) + j += 2 + if data[i] == bs { + continue + } + n = i + break + } + l := n - 10 + if l < 0 { + l = 0 + } + h := n + 10 + + // find the interesting spot - n + fmt.Printf("is[%d]:", l) + for i := l; i < h; i++ { + if i >= ld { + fmt.Printf(" --") + continue + } + fmt.Printf(" %.2x", data[i]) + } + fmt.Printf("\n") + + fmt.Printf("sb[%d]:", l) + for i := l; i < h; i++ { + if i >= ls { + fmt.Printf(" --") + continue + } + bs := hex(s[j])*16 + hex(s[j+1]) + j += 2 + fmt.Printf(" %.2x", bs) + } + fmt.Printf("\n") + + t.Fail() + + // t.Errorf("%s: \ngood: %s\nbad: %x", msg, s, b.Bytes()) + // Print the output in a partially-decoded format; can + // be helpful when updating the test. It produces the output + // that is pasted, with minor edits, into the argument to verify(). 
+ // data := b.Bytes() + // nesting := 0 + // for b.Len() > 0 { + // start := len(data) - b.Len() + // var u uint64 + // u, err := DecodeVarint(b) + // if err != nil { + // fmt.Printf("decode error on varint:", err) + // return + // } + // wire := u & 0x7 + // tag := u >> 3 + // switch wire { + // case WireVarint: + // v, err := DecodeVarint(b) + // if err != nil { + // fmt.Printf("decode error on varint:", err) + // return + // } + // fmt.Printf("\t\t\"%x\" // field %d, encoding %d, value %d\n", + // data[start:len(data)-b.Len()], tag, wire, v) + // case WireFixed32: + // v, err := DecodeFixed32(b) + // if err != nil { + // fmt.Printf("decode error on fixed32:", err) + // return + // } + // fmt.Printf("\t\t\"%x\" // field %d, encoding %d, value %d\n", + // data[start:len(data)-b.Len()], tag, wire, v) + // case WireFixed64: + // v, err := DecodeFixed64(b) + // if err != nil { + // fmt.Printf("decode error on fixed64:", err) + // return + // } + // fmt.Printf("\t\t\"%x\" // field %d, encoding %d, value %d\n", + // data[start:len(data)-b.Len()], tag, wire, v) + // case WireBytes: + // nb, err := DecodeVarint(b) + // if err != nil { + // fmt.Printf("decode error on bytes:", err) + // return + // } + // after_tag := len(data) - b.Len() + // str := make([]byte, nb) + // _, err = b.Read(str) + // if err != nil { + // fmt.Printf("decode error on bytes:", err) + // return + // } + // fmt.Printf("\t\t\"%x\" \"%x\" // field %d, encoding %d (FIELD)\n", + // data[start:after_tag], str, tag, wire) + // case WireStartGroup: + // nesting++ + // fmt.Printf("\t\t\"%x\"\t\t// start group field %d level %d\n", + // data[start:len(data)-b.Len()], tag, nesting) + // case WireEndGroup: + // fmt.Printf("\t\t\"%x\"\t\t// end group field %d level %d\n", + // data[start:len(data)-b.Len()], tag, nesting) + // nesting-- + // default: + // fmt.Printf("unrecognized wire type %d\n", wire) + // return + // } + // } +} + +func hex(c uint8) uint8 { + if '0' <= c && c <= '9' { + return c - '0' + } + if 'a' <= c && c <= 'f' { + return 10 + c - 'a' + } + if 'A' <= c && c <= 'F' { + return 10 + c - 'A' + } + return 0 +} + +func equal(b []byte, s string, t *testing.T) bool { + if 2*len(b) != len(s) { + // fail(fmt.Sprintf("wrong lengths: 2*%d != %d", len(b), len(s)), b, s, t) + fmt.Printf("wrong lengths: 2*%d != %d\n", len(b), len(s)) + return false + } + for i, j := 0, 0; i < len(b); i, j = i+1, j+2 { + x := hex(s[j])*16 + hex(s[j+1]) + if b[i] != x { + // fail(fmt.Sprintf("bad byte[%d]:%x %x", i, b[i], x), b, s, t) + fmt.Printf("bad byte[%d]:%x %x", i, b[i], x) + return false + } + } + return true +} + +func overify(t *testing.T, pb *GoTest, expected string) { + o := old() + err := o.Marshal(pb) + if err != nil { + fmt.Printf("overify marshal-1 err = %v", err) + o.DebugPrint("", o.Bytes()) + t.Fatalf("expected = %s", expected) + } + if !equal(o.Bytes(), expected, t) { + o.DebugPrint("overify neq 1", o.Bytes()) + t.Fatalf("expected = %s", expected) + } + + // Now test Unmarshal by recreating the original buffer. 
+ pbd := new(GoTest) + err = o.Unmarshal(pbd) + if err != nil { + t.Fatalf("overify unmarshal err = %v", err) + o.DebugPrint("", o.Bytes()) + t.Fatalf("string = %s", expected) + } + o.Reset() + err = o.Marshal(pbd) + if err != nil { + t.Errorf("overify marshal-2 err = %v", err) + o.DebugPrint("", o.Bytes()) + t.Fatalf("string = %s", expected) + } + if !equal(o.Bytes(), expected, t) { + o.DebugPrint("overify neq 2", o.Bytes()) + t.Fatalf("string = %s", expected) + } +} + +// Simple tests for numeric encode/decode primitives (varint, etc.) +func TestNumericPrimitives(t *testing.T) { + for i := uint64(0); i < 1e6; i += 111 { + o := old() + if o.EncodeVarint(i) != nil { + t.Error("EncodeVarint") + break + } + x, e := o.DecodeVarint() + if e != nil { + t.Fatal("DecodeVarint") + } + if x != i { + t.Fatal("varint decode fail:", i, x) + } + + o = old() + if o.EncodeFixed32(i) != nil { + t.Fatal("encFixed32") + } + x, e = o.DecodeFixed32() + if e != nil { + t.Fatal("decFixed32") + } + if x != i { + t.Fatal("fixed32 decode fail:", i, x) + } + + o = old() + if o.EncodeFixed64(i*1234567) != nil { + t.Error("encFixed64") + break + } + x, e = o.DecodeFixed64() + if e != nil { + t.Error("decFixed64") + break + } + if x != i*1234567 { + t.Error("fixed64 decode fail:", i*1234567, x) + break + } + + o = old() + i32 := int32(i - 12345) + if o.EncodeZigzag32(uint64(i32)) != nil { + t.Fatal("EncodeZigzag32") + } + x, e = o.DecodeZigzag32() + if e != nil { + t.Fatal("DecodeZigzag32") + } + if x != uint64(uint32(i32)) { + t.Fatal("zigzag32 decode fail:", i32, x) + } + + o = old() + i64 := int64(i - 12345) + if o.EncodeZigzag64(uint64(i64)) != nil { + t.Fatal("EncodeZigzag64") + } + x, e = o.DecodeZigzag64() + if e != nil { + t.Fatal("DecodeZigzag64") + } + if x != uint64(i64) { + t.Fatal("zigzag64 decode fail:", i64, x) + } + } +} + +// fakeMarshaler is a simple struct implementing Marshaler and Message interfaces. +type fakeMarshaler struct { + b []byte + err error +} + +func (f fakeMarshaler) Marshal() ([]byte, error) { + return f.b, f.err +} + +func (f fakeMarshaler) String() string { + return fmt.Sprintf("Bytes: %v Error: %v", f.b, f.err) +} + +func (f fakeMarshaler) ProtoMessage() {} + +func (f fakeMarshaler) Reset() {} + +// Simple tests for proto messages that implement the Marshaler interface. +func TestMarshalerEncoding(t *testing.T) { + tests := []struct { + name string + m Message + want []byte + wantErr error + }{ + { + name: "Marshaler that fails", + m: fakeMarshaler{ + err: errors.New("some marshal err"), + b: []byte{5, 6, 7}, + }, + // Since there's an error, nothing should be written to buffer. 
+ want: nil, + wantErr: errors.New("some marshal err"), + }, + { + name: "Marshaler that succeeds", + m: fakeMarshaler{ + b: []byte{0, 1, 2, 3, 4, 127, 255}, + }, + want: []byte{0, 1, 2, 3, 4, 127, 255}, + wantErr: nil, + }, + } + for _, test := range tests { + b := NewBuffer(nil) + err := b.Marshal(test.m) + if !reflect.DeepEqual(test.wantErr, err) { + t.Errorf("%s: got err %v wanted %v", test.name, err, test.wantErr) + } + if !reflect.DeepEqual(test.want, b.Bytes()) { + t.Errorf("%s: got bytes %v wanted %v", test.name, b.Bytes(), test.want) + } + } +} + +// Simple tests for bytes +func TestBytesPrimitives(t *testing.T) { + o := old() + bytes := []byte{'n', 'o', 'w', ' ', 'i', 's', ' ', 't', 'h', 'e', ' ', 't', 'i', 'm', 'e'} + if o.EncodeRawBytes(bytes) != nil { + t.Error("EncodeRawBytes") + } + decb, e := o.DecodeRawBytes(false) + if e != nil { + t.Error("DecodeRawBytes") + } + equalbytes(bytes, decb, t) +} + +// Simple tests for strings +func TestStringPrimitives(t *testing.T) { + o := old() + s := "now is the time" + if o.EncodeStringBytes(s) != nil { + t.Error("enc_string") + } + decs, e := o.DecodeStringBytes() + if e != nil { + t.Error("dec_string") + } + if s != decs { + t.Error("string encode/decode fail:", s, decs) + } +} + +// Do we catch the "required bit not set" case? +func TestRequiredBit(t *testing.T) { + o := old() + pb := new(GoTest) + err := o.Marshal(pb) + if err == nil { + t.Error("did not catch missing required fields") + } else if strings.Index(err.Error(), "Kind") < 0 { + t.Error("wrong error type:", err) + } +} + +// Check that all fields are nil. +// Clearly silly, and a residue from a more interesting test with an earlier, +// different initialization property, but it once caught a compiler bug so +// it lives. +func checkInitialized(pb *GoTest, t *testing.T) { + if pb.F_BoolDefaulted != nil { + t.Error("New or Reset did not set boolean:", *pb.F_BoolDefaulted) + } + if pb.F_Int32Defaulted != nil { + t.Error("New or Reset did not set int32:", *pb.F_Int32Defaulted) + } + if pb.F_Int64Defaulted != nil { + t.Error("New or Reset did not set int64:", *pb.F_Int64Defaulted) + } + if pb.F_Fixed32Defaulted != nil { + t.Error("New or Reset did not set fixed32:", *pb.F_Fixed32Defaulted) + } + if pb.F_Fixed64Defaulted != nil { + t.Error("New or Reset did not set fixed64:", *pb.F_Fixed64Defaulted) + } + if pb.F_Uint32Defaulted != nil { + t.Error("New or Reset did not set uint32:", *pb.F_Uint32Defaulted) + } + if pb.F_Uint64Defaulted != nil { + t.Error("New or Reset did not set uint64:", *pb.F_Uint64Defaulted) + } + if pb.F_FloatDefaulted != nil { + t.Error("New or Reset did not set float:", *pb.F_FloatDefaulted) + } + if pb.F_DoubleDefaulted != nil { + t.Error("New or Reset did not set double:", *pb.F_DoubleDefaulted) + } + if pb.F_StringDefaulted != nil { + t.Error("New or Reset did not set string:", *pb.F_StringDefaulted) + } + if pb.F_BytesDefaulted != nil { + t.Error("New or Reset did not set bytes:", string(pb.F_BytesDefaulted)) + } + if pb.F_Sint32Defaulted != nil { + t.Error("New or Reset did not set int32:", *pb.F_Sint32Defaulted) + } + if pb.F_Sint64Defaulted != nil { + t.Error("New or Reset did not set int64:", *pb.F_Sint64Defaulted) + } +} + +// Does Reset() reset? 
+func TestReset(t *testing.T) { + pb := initGoTest(true) + // muck with some values + pb.F_BoolDefaulted = Bool(false) + pb.F_Int32Defaulted = Int32(237) + pb.F_Int64Defaulted = Int64(12346) + pb.F_Fixed32Defaulted = Uint32(32000) + pb.F_Fixed64Defaulted = Uint64(666) + pb.F_Uint32Defaulted = Uint32(323232) + pb.F_Uint64Defaulted = nil + pb.F_FloatDefaulted = nil + pb.F_DoubleDefaulted = Float64(0) + pb.F_StringDefaulted = String("gotcha") + pb.F_BytesDefaulted = []byte("asdfasdf") + pb.F_Sint32Defaulted = Int32(123) + pb.F_Sint64Defaulted = Int64(789) + pb.Reset() + checkInitialized(pb, t) +} + +// All required fields set, no defaults provided. +func TestEncodeDecode1(t *testing.T) { + pb := initGoTest(false) + overify(t, pb, + "0807"+ // field 1, encoding 0, value 7 + "220d"+"0a056c6162656c120474797065"+ // field 4, encoding 2 (GoTestField) + "5001"+ // field 10, encoding 0, value 1 + "5803"+ // field 11, encoding 0, value 3 + "6006"+ // field 12, encoding 0, value 6 + "6d20000000"+ // field 13, encoding 5, value 0x20 + "714000000000000000"+ // field 14, encoding 1, value 0x40 + "78a019"+ // field 15, encoding 0, value 0xca0 = 3232 + "8001c032"+ // field 16, encoding 0, value 0x1940 = 6464 + "8d0100004a45"+ // field 17, encoding 5, value 3232.0 + "9101000000000040b940"+ // field 18, encoding 1, value 6464.0 + "9a0106"+"737472696e67"+ // field 19, encoding 2, string "string" + "b304"+ // field 70, encoding 3, start group + "ba0408"+"7265717569726564"+ // field 71, encoding 2, string "required" + "b404"+ // field 70, encoding 4, end group + "aa0605"+"6279746573"+ // field 101, encoding 2, string "bytes" + "b0063f"+ // field 102, encoding 0, 0x3f zigzag32 + "b8067f") // field 103, encoding 0, 0x7f zigzag64 +} + +// All required fields set, defaults provided. 
+func TestEncodeDecode2(t *testing.T) { + pb := initGoTest(true) + overify(t, pb, + "0807"+ // field 1, encoding 0, value 7 + "220d"+"0a056c6162656c120474797065"+ // field 4, encoding 2 (GoTestField) + "5001"+ // field 10, encoding 0, value 1 + "5803"+ // field 11, encoding 0, value 3 + "6006"+ // field 12, encoding 0, value 6 + "6d20000000"+ // field 13, encoding 5, value 32 + "714000000000000000"+ // field 14, encoding 1, value 64 + "78a019"+ // field 15, encoding 0, value 3232 + "8001c032"+ // field 16, encoding 0, value 6464 + "8d0100004a45"+ // field 17, encoding 5, value 3232.0 + "9101000000000040b940"+ // field 18, encoding 1, value 6464.0 + "9a0106"+"737472696e67"+ // field 19, encoding 2 string "string" + "c00201"+ // field 40, encoding 0, value 1 + "c80220"+ // field 41, encoding 0, value 32 + "d00240"+ // field 42, encoding 0, value 64 + "dd0240010000"+ // field 43, encoding 5, value 320 + "e1028002000000000000"+ // field 44, encoding 1, value 640 + "e8028019"+ // field 45, encoding 0, value 3200 + "f0028032"+ // field 46, encoding 0, value 6400 + "fd02e0659948"+ // field 47, encoding 5, value 314159.0 + "81030000000050971041"+ // field 48, encoding 1, value 271828.0 + "8a0310"+"68656c6c6f2c2022776f726c6421220a"+ // field 49, encoding 2 string "hello, \"world!\"\n" + "b304"+ // start group field 70 level 1 + "ba0408"+"7265717569726564"+ // field 71, encoding 2, string "required" + "b404"+ // end group field 70 level 1 + "aa0605"+"6279746573"+ // field 101, encoding 2 string "bytes" + "b0063f"+ // field 102, encoding 0, 0x3f zigzag32 + "b8067f"+ // field 103, encoding 0, 0x7f zigzag64 + "8a1907"+"4269676e6f7365"+ // field 401, encoding 2, string "Bignose" + "90193f"+ // field 402, encoding 0, value 63 + "98197f") // field 403, encoding 0, value 127 + +} + +// All default fields set to their default value by hand +func TestEncodeDecode3(t *testing.T) { + pb := initGoTest(false) + pb.F_BoolDefaulted = Bool(true) + pb.F_Int32Defaulted = Int32(32) + pb.F_Int64Defaulted = Int64(64) + pb.F_Fixed32Defaulted = Uint32(320) + pb.F_Fixed64Defaulted = Uint64(640) + pb.F_Uint32Defaulted = Uint32(3200) + pb.F_Uint64Defaulted = Uint64(6400) + pb.F_FloatDefaulted = Float32(314159) + pb.F_DoubleDefaulted = Float64(271828) + pb.F_StringDefaulted = String("hello, \"world!\"\n") + pb.F_BytesDefaulted = []byte("Bignose") + pb.F_Sint32Defaulted = Int32(-32) + pb.F_Sint64Defaulted = Int64(-64) + + overify(t, pb, + "0807"+ // field 1, encoding 0, value 7 + "220d"+"0a056c6162656c120474797065"+ // field 4, encoding 2 (GoTestField) + "5001"+ // field 10, encoding 0, value 1 + "5803"+ // field 11, encoding 0, value 3 + "6006"+ // field 12, encoding 0, value 6 + "6d20000000"+ // field 13, encoding 5, value 32 + "714000000000000000"+ // field 14, encoding 1, value 64 + "78a019"+ // field 15, encoding 0, value 3232 + "8001c032"+ // field 16, encoding 0, value 6464 + "8d0100004a45"+ // field 17, encoding 5, value 3232.0 + "9101000000000040b940"+ // field 18, encoding 1, value 6464.0 + "9a0106"+"737472696e67"+ // field 19, encoding 2 string "string" + "c00201"+ // field 40, encoding 0, value 1 + "c80220"+ // field 41, encoding 0, value 32 + "d00240"+ // field 42, encoding 0, value 64 + "dd0240010000"+ // field 43, encoding 5, value 320 + "e1028002000000000000"+ // field 44, encoding 1, value 640 + "e8028019"+ // field 45, encoding 0, value 3200 + "f0028032"+ // field 46, encoding 0, value 6400 + "fd02e0659948"+ // field 47, encoding 5, value 314159.0 + "81030000000050971041"+ // field 48, encoding 1, value 
271828.0 + "8a0310"+"68656c6c6f2c2022776f726c6421220a"+ // field 49, encoding 2 string "hello, \"world!\"\n" + "b304"+ // start group field 70 level 1 + "ba0408"+"7265717569726564"+ // field 71, encoding 2, string "required" + "b404"+ // end group field 70 level 1 + "aa0605"+"6279746573"+ // field 101, encoding 2 string "bytes" + "b0063f"+ // field 102, encoding 0, 0x3f zigzag32 + "b8067f"+ // field 103, encoding 0, 0x7f zigzag64 + "8a1907"+"4269676e6f7365"+ // field 401, encoding 2, string "Bignose" + "90193f"+ // field 402, encoding 0, value 63 + "98197f") // field 403, encoding 0, value 127 + +} + +// All required fields set, defaults provided, all non-defaulted optional fields have values. +func TestEncodeDecode4(t *testing.T) { + pb := initGoTest(true) + pb.Table = String("hello") + pb.Param = Int32(7) + pb.OptionalField = initGoTestField() + pb.F_BoolOptional = Bool(true) + pb.F_Int32Optional = Int32(32) + pb.F_Int64Optional = Int64(64) + pb.F_Fixed32Optional = Uint32(3232) + pb.F_Fixed64Optional = Uint64(6464) + pb.F_Uint32Optional = Uint32(323232) + pb.F_Uint64Optional = Uint64(646464) + pb.F_FloatOptional = Float32(32.) + pb.F_DoubleOptional = Float64(64.) + pb.F_StringOptional = String("hello") + pb.F_BytesOptional = []byte("Bignose") + pb.F_Sint32Optional = Int32(-32) + pb.F_Sint64Optional = Int64(-64) + pb.Optionalgroup = initGoTest_OptionalGroup() + + overify(t, pb, + "0807"+ // field 1, encoding 0, value 7 + "1205"+"68656c6c6f"+ // field 2, encoding 2, string "hello" + "1807"+ // field 3, encoding 0, value 7 + "220d"+"0a056c6162656c120474797065"+ // field 4, encoding 2 (GoTestField) + "320d"+"0a056c6162656c120474797065"+ // field 6, encoding 2 (GoTestField) + "5001"+ // field 10, encoding 0, value 1 + "5803"+ // field 11, encoding 0, value 3 + "6006"+ // field 12, encoding 0, value 6 + "6d20000000"+ // field 13, encoding 5, value 32 + "714000000000000000"+ // field 14, encoding 1, value 64 + "78a019"+ // field 15, encoding 0, value 3232 + "8001c032"+ // field 16, encoding 0, value 6464 + "8d0100004a45"+ // field 17, encoding 5, value 3232.0 + "9101000000000040b940"+ // field 18, encoding 1, value 6464.0 + "9a0106"+"737472696e67"+ // field 19, encoding 2 string "string" + "f00101"+ // field 30, encoding 0, value 1 + "f80120"+ // field 31, encoding 0, value 32 + "800240"+ // field 32, encoding 0, value 64 + "8d02a00c0000"+ // field 33, encoding 5, value 3232 + "91024019000000000000"+ // field 34, encoding 1, value 6464 + "9802a0dd13"+ // field 35, encoding 0, value 323232 + "a002c0ba27"+ // field 36, encoding 0, value 646464 + "ad0200000042"+ // field 37, encoding 5, value 32.0 + "b1020000000000005040"+ // field 38, encoding 1, value 64.0 + "ba0205"+"68656c6c6f"+ // field 39, encoding 2, string "hello" + "c00201"+ // field 40, encoding 0, value 1 + "c80220"+ // field 41, encoding 0, value 32 + "d00240"+ // field 42, encoding 0, value 64 + "dd0240010000"+ // field 43, encoding 5, value 320 + "e1028002000000000000"+ // field 44, encoding 1, value 640 + "e8028019"+ // field 45, encoding 0, value 3200 + "f0028032"+ // field 46, encoding 0, value 6400 + "fd02e0659948"+ // field 47, encoding 5, value 314159.0 + "81030000000050971041"+ // field 48, encoding 1, value 271828.0 + "8a0310"+"68656c6c6f2c2022776f726c6421220a"+ // field 49, encoding 2 string "hello, \"world!\"\n" + "b304"+ // start group field 70 level 1 + "ba0408"+"7265717569726564"+ // field 71, encoding 2, string "required" + "b404"+ // end group field 70 level 1 + "d305"+ // start group field 90 level 1 + 
"da0508"+"6f7074696f6e616c"+ // field 91, encoding 2, string "optional" + "d405"+ // end group field 90 level 1 + "aa0605"+"6279746573"+ // field 101, encoding 2 string "bytes" + "b0063f"+ // field 102, encoding 0, 0x3f zigzag32 + "b8067f"+ // field 103, encoding 0, 0x7f zigzag64 + "ea1207"+"4269676e6f7365"+ // field 301, encoding 2, string "Bignose" + "f0123f"+ // field 302, encoding 0, value 63 + "f8127f"+ // field 303, encoding 0, value 127 + "8a1907"+"4269676e6f7365"+ // field 401, encoding 2, string "Bignose" + "90193f"+ // field 402, encoding 0, value 63 + "98197f") // field 403, encoding 0, value 127 + +} + +// All required fields set, defaults provided, all repeated fields given two values. +func TestEncodeDecode5(t *testing.T) { + pb := initGoTest(true) + pb.RepeatedField = []*GoTestField{initGoTestField(), initGoTestField()} + pb.F_BoolRepeated = []bool{false, true} + pb.F_Int32Repeated = []int32{32, 33} + pb.F_Int64Repeated = []int64{64, 65} + pb.F_Fixed32Repeated = []uint32{3232, 3333} + pb.F_Fixed64Repeated = []uint64{6464, 6565} + pb.F_Uint32Repeated = []uint32{323232, 333333} + pb.F_Uint64Repeated = []uint64{646464, 656565} + pb.F_FloatRepeated = []float32{32., 33.} + pb.F_DoubleRepeated = []float64{64., 65.} + pb.F_StringRepeated = []string{"hello", "sailor"} + pb.F_BytesRepeated = [][]byte{[]byte("big"), []byte("nose")} + pb.F_Sint32Repeated = []int32{32, -32} + pb.F_Sint64Repeated = []int64{64, -64} + pb.Repeatedgroup = []*GoTest_RepeatedGroup{initGoTest_RepeatedGroup(), initGoTest_RepeatedGroup()} + + overify(t, pb, + "0807"+ // field 1, encoding 0, value 7 + "220d"+"0a056c6162656c120474797065"+ // field 4, encoding 2 (GoTestField) + "2a0d"+"0a056c6162656c120474797065"+ // field 5, encoding 2 (GoTestField) + "2a0d"+"0a056c6162656c120474797065"+ // field 5, encoding 2 (GoTestField) + "5001"+ // field 10, encoding 0, value 1 + "5803"+ // field 11, encoding 0, value 3 + "6006"+ // field 12, encoding 0, value 6 + "6d20000000"+ // field 13, encoding 5, value 32 + "714000000000000000"+ // field 14, encoding 1, value 64 + "78a019"+ // field 15, encoding 0, value 3232 + "8001c032"+ // field 16, encoding 0, value 6464 + "8d0100004a45"+ // field 17, encoding 5, value 3232.0 + "9101000000000040b940"+ // field 18, encoding 1, value 6464.0 + "9a0106"+"737472696e67"+ // field 19, encoding 2 string "string" + "a00100"+ // field 20, encoding 0, value 0 + "a00101"+ // field 20, encoding 0, value 1 + "a80120"+ // field 21, encoding 0, value 32 + "a80121"+ // field 21, encoding 0, value 33 + "b00140"+ // field 22, encoding 0, value 64 + "b00141"+ // field 22, encoding 0, value 65 + "bd01a00c0000"+ // field 23, encoding 5, value 3232 + "bd01050d0000"+ // field 23, encoding 5, value 3333 + "c1014019000000000000"+ // field 24, encoding 1, value 6464 + "c101a519000000000000"+ // field 24, encoding 1, value 6565 + "c801a0dd13"+ // field 25, encoding 0, value 323232 + "c80195ac14"+ // field 25, encoding 0, value 333333 + "d001c0ba27"+ // field 26, encoding 0, value 646464 + "d001b58928"+ // field 26, encoding 0, value 656565 + "dd0100000042"+ // field 27, encoding 5, value 32.0 + "dd0100000442"+ // field 27, encoding 5, value 33.0 + "e1010000000000005040"+ // field 28, encoding 1, value 64.0 + "e1010000000000405040"+ // field 28, encoding 1, value 65.0 + "ea0105"+"68656c6c6f"+ // field 29, encoding 2, string "hello" + "ea0106"+"7361696c6f72"+ // field 29, encoding 2, string "sailor" + "c00201"+ // field 40, encoding 0, value 1 + "c80220"+ // field 41, encoding 0, value 32 + "d00240"+ // field 
42, encoding 0, value 64 + "dd0240010000"+ // field 43, encoding 5, value 320 + "e1028002000000000000"+ // field 44, encoding 1, value 640 + "e8028019"+ // field 45, encoding 0, value 3200 + "f0028032"+ // field 46, encoding 0, value 6400 + "fd02e0659948"+ // field 47, encoding 5, value 314159.0 + "81030000000050971041"+ // field 48, encoding 1, value 271828.0 + "8a0310"+"68656c6c6f2c2022776f726c6421220a"+ // field 49, encoding 2 string "hello, \"world!\"\n" + "b304"+ // start group field 70 level 1 + "ba0408"+"7265717569726564"+ // field 71, encoding 2, string "required" + "b404"+ // end group field 70 level 1 + "8305"+ // start group field 80 level 1 + "8a0508"+"7265706561746564"+ // field 81, encoding 2, string "repeated" + "8405"+ // end group field 80 level 1 + "8305"+ // start group field 80 level 1 + "8a0508"+"7265706561746564"+ // field 81, encoding 2, string "repeated" + "8405"+ // end group field 80 level 1 + "aa0605"+"6279746573"+ // field 101, encoding 2 string "bytes" + "b0063f"+ // field 102, encoding 0, 0x3f zigzag32 + "b8067f"+ // field 103, encoding 0, 0x7f zigzag64 + "ca0c03"+"626967"+ // field 201, encoding 2, string "big" + "ca0c04"+"6e6f7365"+ // field 201, encoding 2, string "nose" + "d00c40"+ // field 202, encoding 0, value 32 + "d00c3f"+ // field 202, encoding 0, value -32 + "d80c8001"+ // field 203, encoding 0, value 64 + "d80c7f"+ // field 203, encoding 0, value -64 + "8a1907"+"4269676e6f7365"+ // field 401, encoding 2, string "Bignose" + "90193f"+ // field 402, encoding 0, value 63 + "98197f") // field 403, encoding 0, value 127 + +} + +// All required fields set, all packed repeated fields given two values. +func TestEncodeDecode6(t *testing.T) { + pb := initGoTest(false) + pb.F_BoolRepeatedPacked = []bool{false, true} + pb.F_Int32RepeatedPacked = []int32{32, 33} + pb.F_Int64RepeatedPacked = []int64{64, 65} + pb.F_Fixed32RepeatedPacked = []uint32{3232, 3333} + pb.F_Fixed64RepeatedPacked = []uint64{6464, 6565} + pb.F_Uint32RepeatedPacked = []uint32{323232, 333333} + pb.F_Uint64RepeatedPacked = []uint64{646464, 656565} + pb.F_FloatRepeatedPacked = []float32{32., 33.} + pb.F_DoubleRepeatedPacked = []float64{64., 65.} + pb.F_Sint32RepeatedPacked = []int32{32, -32} + pb.F_Sint64RepeatedPacked = []int64{64, -64} + + overify(t, pb, + "0807"+ // field 1, encoding 0, value 7 + "220d"+"0a056c6162656c120474797065"+ // field 4, encoding 2 (GoTestField) + "5001"+ // field 10, encoding 0, value 1 + "5803"+ // field 11, encoding 0, value 3 + "6006"+ // field 12, encoding 0, value 6 + "6d20000000"+ // field 13, encoding 5, value 32 + "714000000000000000"+ // field 14, encoding 1, value 64 + "78a019"+ // field 15, encoding 0, value 3232 + "8001c032"+ // field 16, encoding 0, value 6464 + "8d0100004a45"+ // field 17, encoding 5, value 3232.0 + "9101000000000040b940"+ // field 18, encoding 1, value 6464.0 + "9a0106"+"737472696e67"+ // field 19, encoding 2 string "string" + "9203020001"+ // field 50, encoding 2, 2 bytes, value 0, value 1 + "9a03022021"+ // field 51, encoding 2, 2 bytes, value 32, value 33 + "a203024041"+ // field 52, encoding 2, 2 bytes, value 64, value 65 + "aa0308"+ // field 53, encoding 2, 8 bytes + "a00c0000050d0000"+ // value 3232, value 3333 + "b20310"+ // field 54, encoding 2, 16 bytes + "4019000000000000a519000000000000"+ // value 6464, value 6565 + "ba0306"+ // field 55, encoding 2, 6 bytes + "a0dd1395ac14"+ // value 323232, value 333333 + "c20306"+ // field 56, encoding 2, 6 bytes + "c0ba27b58928"+ // value 646464, value 656565 + "ca0308"+ // field 57, 
encoding 2, 8 bytes + "0000004200000442"+ // value 32.0, value 33.0 + "d20310"+ // field 58, encoding 2, 16 bytes + "00000000000050400000000000405040"+ // value 64.0, value 65.0 + "b304"+ // start group field 70 level 1 + "ba0408"+"7265717569726564"+ // field 71, encoding 2, string "required" + "b404"+ // end group field 70 level 1 + "aa0605"+"6279746573"+ // field 101, encoding 2 string "bytes" + "b0063f"+ // field 102, encoding 0, 0x3f zigzag32 + "b8067f"+ // field 103, encoding 0, 0x7f zigzag64 + "b21f02"+ // field 502, encoding 2, 2 bytes + "403f"+ // value 32, value -32 + "ba1f03"+ // field 503, encoding 2, 3 bytes + "80017f") // value 64, value -64 +} + +// Test that we can encode empty bytes fields. +func TestEncodeDecodeBytes1(t *testing.T) { + pb := initGoTest(false) + + // Create our bytes + pb.F_BytesRequired = []byte{} + pb.F_BytesRepeated = [][]byte{{}} + pb.F_BytesOptional = []byte{} + + d, err := Marshal(pb) + if err != nil { + t.Error(err) + } + + pbd := new(GoTest) + if err := Unmarshal(d, pbd); err != nil { + t.Error(err) + } + + if pbd.F_BytesRequired == nil || len(pbd.F_BytesRequired) != 0 { + t.Error("required empty bytes field is incorrect") + } + if pbd.F_BytesRepeated == nil || len(pbd.F_BytesRepeated) == 1 && pbd.F_BytesRepeated[0] == nil { + t.Error("repeated empty bytes field is incorrect") + } + if pbd.F_BytesOptional == nil || len(pbd.F_BytesOptional) != 0 { + t.Error("optional empty bytes field is incorrect") + } +} + +// Test that we encode nil-valued fields of a repeated bytes field correctly. +// Since entries in a repeated field cannot be nil, nil must mean empty value. +func TestEncodeDecodeBytes2(t *testing.T) { + pb := initGoTest(false) + + // Create our bytes + pb.F_BytesRepeated = [][]byte{nil} + + d, err := Marshal(pb) + if err != nil { + t.Error(err) + } + + pbd := new(GoTest) + if err := Unmarshal(d, pbd); err != nil { + t.Error(err) + } + + if len(pbd.F_BytesRepeated) != 1 || pbd.F_BytesRepeated[0] == nil { + t.Error("Unexpected value for repeated bytes field") + } +} + +// All required fields set, defaults provided, all repeated fields given two values. +func TestSkippingUnrecognizedFields(t *testing.T) { + o := old() + pb := initGoTestField() + + // Marshal it normally. + o.Marshal(pb) + + // Now new a GoSkipTest record. + skip := &GoSkipTest{ + SkipInt32: Int32(32), + SkipFixed32: Uint32(3232), + SkipFixed64: Uint64(6464), + SkipString: String("skipper"), + Skipgroup: &GoSkipTest_SkipGroup{ + GroupInt32: Int32(75), + GroupString: String("wxyz"), + }, + } + + // Marshal it into same buffer. + o.Marshal(skip) + + pbd := new(GoTestField) + o.Unmarshal(pbd) + + // The __unrecognized field should be a marshaling of GoSkipTest + skipd := new(GoSkipTest) + + o.SetBuf(pbd.XXX_unrecognized) + o.Unmarshal(skipd) + + if *skipd.SkipInt32 != *skip.SkipInt32 { + t.Error("skip int32", skipd.SkipInt32) + } + if *skipd.SkipFixed32 != *skip.SkipFixed32 { + t.Error("skip fixed32", skipd.SkipFixed32) + } + if *skipd.SkipFixed64 != *skip.SkipFixed64 { + t.Error("skip fixed64", skipd.SkipFixed64) + } + if *skipd.SkipString != *skip.SkipString { + t.Error("skip string", *skipd.SkipString) + } + if *skipd.Skipgroup.GroupInt32 != *skip.Skipgroup.GroupInt32 { + t.Error("skip group int32", skipd.Skipgroup.GroupInt32) + } + if *skipd.Skipgroup.GroupString != *skip.Skipgroup.GroupString { + t.Error("skip group string", *skipd.Skipgroup.GroupString) + } +} + +// Check that unrecognized fields of a submessage are preserved. 
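+// A wire-format note for the test below (illustrative, derived from the
+// expected bytes it asserts): an unrecognized field survives verbatim in
+// XXX_unrecognized. The bytes "\x12\x05carbs" decode as tag 0x12, meaning
+// field number 2 (0x12 >> 3) with wire type 2 (length-delimited), followed
+// by the length 5 and the payload "carbs".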
+func TestSubmessageUnrecognizedFields(t *testing.T) { + nm := &NewMessage{ + Nested: &NewMessage_Nested{ + Name: String("Nigel"), + FoodGroup: String("carbs"), + }, + } + b, err := Marshal(nm) + if err != nil { + t.Fatalf("Marshal of NewMessage: %v", err) + } + + // Unmarshal into an OldMessage. + om := new(OldMessage) + if err := Unmarshal(b, om); err != nil { + t.Fatalf("Unmarshal to OldMessage: %v", err) + } + exp := &OldMessage{ + Nested: &OldMessage_Nested{ + Name: String("Nigel"), + // normal protocol buffer users should not do this + XXX_unrecognized: []byte("\x12\x05carbs"), + }, + } + if !Equal(om, exp) { + t.Errorf("om = %v, want %v", om, exp) + } + + // Clone the OldMessage. + om = Clone(om).(*OldMessage) + if !Equal(om, exp) { + t.Errorf("Clone(om) = %v, want %v", om, exp) + } + + // Marshal the OldMessage, then unmarshal it into an empty NewMessage. + if b, err = Marshal(om); err != nil { + t.Fatalf("Marshal of OldMessage: %v", err) + } + t.Logf("Marshal(%v) -> %q", om, b) + nm2 := new(NewMessage) + if err := Unmarshal(b, nm2); err != nil { + t.Fatalf("Unmarshal to NewMessage: %v", err) + } + if !Equal(nm, nm2) { + t.Errorf("NewMessage round-trip: %v => %v", nm, nm2) + } +} + +// Check that an int32 field can be upgraded to an int64 field. +func TestNegativeInt32(t *testing.T) { + om := &OldMessage{ + Num: Int32(-1), + } + b, err := Marshal(om) + if err != nil { + t.Fatalf("Marshal of OldMessage: %v", err) + } + + // Check the size. It should be 11 bytes; + // 1 for the field/wire type, and 10 for the negative number. + if len(b) != 11 { + t.Errorf("%v marshaled as %q, wanted 11 bytes", om, b) + } + + // Unmarshal into a NewMessage. + nm := new(NewMessage) + if err := Unmarshal(b, nm); err != nil { + t.Fatalf("Unmarshal to NewMessage: %v", err) + } + want := &NewMessage{ + Num: Int64(-1), + } + if !Equal(nm, want) { + t.Errorf("nm = %v, want %v", nm, want) + } +} + +// Check that we can grow an array (repeated field) to have many elements. +// This test doesn't depend only on our encoding; for variety, it makes sure +// we create, encode, and decode the correct contents explicitly. It's therefore +// a bit messier. +// This test also uses (and hence tests) the Marshal/Unmarshal functions +// instead of the methods. +func TestBigRepeated(t *testing.T) { + pb := initGoTest(true) + + // Create the arrays + const N = 50 // Internally the library starts much smaller. + pb.Repeatedgroup = make([]*GoTest_RepeatedGroup, N) + pb.F_Sint64Repeated = make([]int64, N) + pb.F_Sint32Repeated = make([]int32, N) + pb.F_BytesRepeated = make([][]byte, N) + pb.F_StringRepeated = make([]string, N) + pb.F_DoubleRepeated = make([]float64, N) + pb.F_FloatRepeated = make([]float32, N) + pb.F_Uint64Repeated = make([]uint64, N) + pb.F_Uint32Repeated = make([]uint32, N) + pb.F_Fixed64Repeated = make([]uint64, N) + pb.F_Fixed32Repeated = make([]uint32, N) + pb.F_Int64Repeated = make([]int64, N) + pb.F_Int32Repeated = make([]int32, N) + pb.F_BoolRepeated = make([]bool, N) + pb.RepeatedField = make([]*GoTestField, N) + + // Fill in the arrays with checkable values. 
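+	// (Every slice element is derived from its index, so the decode loop
+	// further down can check each value against i directly.)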
+	igtf := initGoTestField()
+	igtrg := initGoTest_RepeatedGroup()
+	for i := 0; i < N; i++ {
+		pb.Repeatedgroup[i] = igtrg
+		pb.F_Sint64Repeated[i] = int64(i)
+		pb.F_Sint32Repeated[i] = int32(i)
+		s := fmt.Sprint(i)
+		pb.F_BytesRepeated[i] = []byte(s)
+		pb.F_StringRepeated[i] = s
+		pb.F_DoubleRepeated[i] = float64(i)
+		pb.F_FloatRepeated[i] = float32(i)
+		pb.F_Uint64Repeated[i] = uint64(i)
+		pb.F_Uint32Repeated[i] = uint32(i)
+		pb.F_Fixed64Repeated[i] = uint64(i)
+		pb.F_Fixed32Repeated[i] = uint32(i)
+		pb.F_Int64Repeated[i] = int64(i)
+		pb.F_Int32Repeated[i] = int32(i)
+		pb.F_BoolRepeated[i] = i%2 == 0
+		pb.RepeatedField[i] = igtf
+	}
+
+	// Marshal.
+	buf, _ := Marshal(pb)
+
+	// Now test Unmarshal by recreating the original buffer.
+	pbd := new(GoTest)
+	Unmarshal(buf, pbd)
+
+	// Check the checkable values.
+	for i := uint64(0); i < N; i++ {
+		if pbd.Repeatedgroup[i] == nil { // TODO: more checking?
+			t.Error("pbd.Repeatedgroup bad")
+		}
+		var x uint64
+		x = uint64(pbd.F_Sint64Repeated[i])
+		if x != i {
+			t.Error("pbd.F_Sint64Repeated bad", x, i)
+		}
+		x = uint64(pbd.F_Sint32Repeated[i])
+		if x != i {
+			t.Error("pbd.F_Sint32Repeated bad", x, i)
+		}
+		s := fmt.Sprint(i)
+		equalbytes(pbd.F_BytesRepeated[i], []byte(s), t)
+		if pbd.F_StringRepeated[i] != s {
+			t.Error("pbd.F_StringRepeated bad", pbd.F_StringRepeated[i], i)
+		}
+		x = uint64(pbd.F_DoubleRepeated[i])
+		if x != i {
+			t.Error("pbd.F_DoubleRepeated bad", x, i)
+		}
+		x = uint64(pbd.F_FloatRepeated[i])
+		if x != i {
+			t.Error("pbd.F_FloatRepeated bad", x, i)
+		}
+		x = pbd.F_Uint64Repeated[i]
+		if x != i {
+			t.Error("pbd.F_Uint64Repeated bad", x, i)
+		}
+		x = uint64(pbd.F_Uint32Repeated[i])
+		if x != i {
+			t.Error("pbd.F_Uint32Repeated bad", x, i)
+		}
+		x = pbd.F_Fixed64Repeated[i]
+		if x != i {
+			t.Error("pbd.F_Fixed64Repeated bad", x, i)
+		}
+		x = uint64(pbd.F_Fixed32Repeated[i])
+		if x != i {
+			t.Error("pbd.F_Fixed32Repeated bad", x, i)
+		}
+		x = uint64(pbd.F_Int64Repeated[i])
+		if x != i {
+			t.Error("pbd.F_Int64Repeated bad", x, i)
+		}
+		x = uint64(pbd.F_Int32Repeated[i])
+		if x != i {
+			t.Error("pbd.F_Int32Repeated bad", x, i)
+		}
+		if pbd.F_BoolRepeated[i] != (i%2 == 0) {
+			t.Error("pbd.F_BoolRepeated bad", x, i)
+		}
+		if pbd.RepeatedField[i] == nil { // TODO: more checking?
+			t.Error("pbd.RepeatedField bad")
+		}
+	}
+}
+
+// Verify we give a useful message when decoding to the wrong structure type.
+func TestTypeMismatch(t *testing.T) {
+	pb1 := initGoTest(true)
+
+	// Marshal
+	o := old()
+	o.Marshal(pb1)
+
+	// Now Unmarshal it to the wrong type.
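+	// (The buffer holds a GoTest message, whose field 1 is a varint, while
+	// GoTestField declares field 1 as a string, wire type 2 -- hence the
+	// "bad wiretype" error expected below.)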
+ pb2 := initGoTestField() + err := o.Unmarshal(pb2) + if err == nil { + t.Error("expected error, got no error") + } else if !strings.Contains(err.Error(), "bad wiretype") { + t.Error("expected bad wiretype error, got", err) + } +} + +func encodeDecode(t *testing.T, in, out Message, msg string) { + buf, err := Marshal(in) + if err != nil { + t.Fatalf("failed marshaling %v: %v", msg, err) + } + if err := Unmarshal(buf, out); err != nil { + t.Fatalf("failed unmarshaling %v: %v", msg, err) + } +} + +func TestPackedNonPackedDecoderSwitching(t *testing.T) { + np, p := new(NonPackedTest), new(PackedTest) + + // non-packed -> packed + np.A = []int32{0, 1, 1, 2, 3, 5} + encodeDecode(t, np, p, "non-packed -> packed") + if !reflect.DeepEqual(np.A, p.B) { + t.Errorf("failed non-packed -> packed; np.A=%+v, p.B=%+v", np.A, p.B) + } + + // packed -> non-packed + np.Reset() + p.B = []int32{3, 1, 4, 1, 5, 9} + encodeDecode(t, p, np, "packed -> non-packed") + if !reflect.DeepEqual(p.B, np.A) { + t.Errorf("failed packed -> non-packed; p.B=%+v, np.A=%+v", p.B, np.A) + } +} + +func TestProto1RepeatedGroup(t *testing.T) { + pb := &MessageList{ + Message: []*MessageList_Message{ + { + Name: String("blah"), + Count: Int32(7), + }, + // NOTE: pb.Message[1] is a nil + nil, + }, + } + + o := old() + if err := o.Marshal(pb); err != ErrRepeatedHasNil { + t.Fatalf("unexpected or no error when marshaling: %v", err) + } +} + +// Test that enums work. Checks for a bug introduced by making enums +// named types instead of int32: newInt32FromUint64 would crash with +// a type mismatch in reflect.PointTo. +func TestEnum(t *testing.T) { + pb := new(GoEnum) + pb.Foo = FOO_FOO1.Enum() + o := old() + if err := o.Marshal(pb); err != nil { + t.Fatal("error encoding enum:", err) + } + pb1 := new(GoEnum) + if err := o.Unmarshal(pb1); err != nil { + t.Fatal("error decoding enum:", err) + } + if *pb1.Foo != FOO_FOO1 { + t.Error("expected 7 but got ", *pb1.Foo) + } +} + +// Enum types have String methods. Check that enum fields can be printed. +// We don't care what the value actually is, just as long as it doesn't crash. +func TestPrintingNilEnumFields(t *testing.T) { + pb := new(GoEnum) + fmt.Sprintf("%+v", pb) +} + +// Verify that absent required fields cause Marshal/Unmarshal to return errors. +func TestRequiredFieldEnforcement(t *testing.T) { + pb := new(GoTestField) + _, err := Marshal(pb) + if err == nil { + t.Error("marshal: expected error, got nil") + } else if strings.Index(err.Error(), "Label") < 0 { + t.Errorf("marshal: bad error type: %v", err) + } + + // A slightly sneaky, yet valid, proto. It encodes the same required field twice, + // so simply counting the required fields is insufficient. + // field 1, encoding 2, value "hi" + buf := []byte("\x0A\x02hi\x0A\x02hi") + err = Unmarshal(buf, pb) + if err == nil { + t.Error("unmarshal: expected error, got nil") + } else if strings.Index(err.Error(), "{Unknown}") < 0 { + t.Errorf("unmarshal: bad error type: %v", err) + } +} + +func TestTypedNilMarshal(t *testing.T) { + // A typed nil should return ErrNil and not crash. + _, err := Marshal((*GoEnum)(nil)) + if err != ErrNil { + t.Errorf("Marshal: got err %v, want ErrNil", err) + } +} + +// A type that implements the Marshaler interface, but is not nillable. 
+type nonNillableInt uint64 + +func (nni nonNillableInt) Marshal() ([]byte, error) { + return EncodeVarint(uint64(nni)), nil +} + +type NNIMessage struct { + nni nonNillableInt +} + +func (*NNIMessage) Reset() {} +func (*NNIMessage) String() string { return "" } +func (*NNIMessage) ProtoMessage() {} + +// A type that implements the Marshaler interface and is nillable. +type nillableMessage struct { + x uint64 +} + +func (nm *nillableMessage) Marshal() ([]byte, error) { + return EncodeVarint(nm.x), nil +} + +type NMMessage struct { + nm *nillableMessage +} + +func (*NMMessage) Reset() {} +func (*NMMessage) String() string { return "" } +func (*NMMessage) ProtoMessage() {} + +// Verify a type that uses the Marshaler interface, but has a nil pointer. +func TestNilMarshaler(t *testing.T) { + // Try a struct with a Marshaler field that is nil. + // It should be directly marshable. + nmm := new(NMMessage) + if _, err := Marshal(nmm); err != nil { + t.Error("unexpected error marshaling nmm: ", err) + } + + // Try a struct with a Marshaler field that is not nillable. + nnim := new(NNIMessage) + nnim.nni = 7 + var _ Marshaler = nnim.nni // verify it is truly a Marshaler + if _, err := Marshal(nnim); err != nil { + t.Error("unexpected error marshaling nnim: ", err) + } +} + +func TestAllSetDefaults(t *testing.T) { + // Exercise SetDefaults with all scalar field types. + m := &Defaults{ + // NaN != NaN, so override that here. + F_Nan: Float32(1.7), + } + expected := &Defaults{ + F_Bool: Bool(true), + F_Int32: Int32(32), + F_Int64: Int64(64), + F_Fixed32: Uint32(320), + F_Fixed64: Uint64(640), + F_Uint32: Uint32(3200), + F_Uint64: Uint64(6400), + F_Float: Float32(314159), + F_Double: Float64(271828), + F_String: String(`hello, "world!"` + "\n"), + F_Bytes: []byte("Bignose"), + F_Sint32: Int32(-32), + F_Sint64: Int64(-64), + F_Enum: Defaults_GREEN.Enum(), + F_Pinf: Float32(float32(math.Inf(1))), + F_Ninf: Float32(float32(math.Inf(-1))), + F_Nan: Float32(1.7), + StrZero: String(""), + } + SetDefaults(m) + if !Equal(m, expected) { + t.Errorf("SetDefaults failed\n got %v\nwant %v", m, expected) + } +} + +func TestSetDefaultsWithSetField(t *testing.T) { + // Check that a set value is not overridden. 
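+	// (SetDefaults only fills fields whose pointers are nil, so the explicit
+	// value below must survive the call.)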
+ m := &Defaults{ + F_Int32: Int32(12), + } + SetDefaults(m) + if v := m.GetF_Int32(); v != 12 { + t.Errorf("m.FInt32 = %v, want 12", v) + } +} + +func TestSetDefaultsWithSubMessage(t *testing.T) { + m := &OtherMessage{ + Key: Int64(123), + Inner: &InnerMessage{ + Host: String("gopher"), + }, + } + expected := &OtherMessage{ + Key: Int64(123), + Inner: &InnerMessage{ + Host: String("gopher"), + Port: Int32(4000), + }, + } + SetDefaults(m) + if !Equal(m, expected) { + t.Errorf("\n got %v\nwant %v", m, expected) + } +} + +func TestSetDefaultsWithRepeatedSubMessage(t *testing.T) { + m := &MyMessage{ + RepInner: []*InnerMessage{{}}, + } + expected := &MyMessage{ + RepInner: []*InnerMessage{{ + Port: Int32(4000), + }}, + } + SetDefaults(m) + if !Equal(m, expected) { + t.Errorf("\n got %v\nwant %v", m, expected) + } +} + +func TestMaximumTagNumber(t *testing.T) { + m := &MaxTag{ + LastField: String("natural goat essence"), + } + buf, err := Marshal(m) + if err != nil { + t.Fatalf("proto.Marshal failed: %v", err) + } + m2 := new(MaxTag) + if err := Unmarshal(buf, m2); err != nil { + t.Fatalf("proto.Unmarshal failed: %v", err) + } + if got, want := m2.GetLastField(), *m.LastField; got != want { + t.Errorf("got %q, want %q", got, want) + } +} + +func TestJSON(t *testing.T) { + m := &MyMessage{ + Count: Int32(4), + Pet: []string{"bunny", "kitty"}, + Inner: &InnerMessage{ + Host: String("cauchy"), + }, + Bikeshed: MyMessage_GREEN.Enum(), + } + const expected = `{"count":4,"pet":["bunny","kitty"],"inner":{"host":"cauchy"},"bikeshed":1}` + + b, err := json.Marshal(m) + if err != nil { + t.Fatalf("json.Marshal failed: %v", err) + } + s := string(b) + if s != expected { + t.Errorf("got %s\nwant %s", s, expected) + } + + received := new(MyMessage) + if err := json.Unmarshal(b, received); err != nil { + t.Fatalf("json.Unmarshal failed: %v", err) + } + if !Equal(received, m) { + t.Fatalf("got %s, want %s", received, m) + } + + // Test unmarshalling of JSON with symbolic enum name. + const old = `{"count":4,"pet":["bunny","kitty"],"inner":{"host":"cauchy"},"bikeshed":"GREEN"}` + received.Reset() + if err := json.Unmarshal([]byte(old), received); err != nil { + t.Fatalf("json.Unmarshal failed: %v", err) + } + if !Equal(received, m) { + t.Fatalf("got %s, want %s", received, m) + } +} + +func TestBadWireType(t *testing.T) { + b := []byte{7<<3 | 6} // field 7, wire type 6 + pb := new(OtherMessage) + if err := Unmarshal(b, pb); err == nil { + t.Errorf("Unmarshal did not fail") + } else if !strings.Contains(err.Error(), "unknown wire type") { + t.Errorf("wrong error: %v", err) + } +} + +func TestBytesWithInvalidLength(t *testing.T) { + // If a byte sequence has an invalid (negative) length, Unmarshal should not panic. + b := []byte{2<<3 | WireBytes, 0xff, 0xff, 0xff, 0xff, 0xff, 0} + Unmarshal(b, new(MyMessage)) +} + +func TestLengthOverflow(t *testing.T) { + // Overflowing a length should not panic. + b := []byte{2<<3 | WireBytes, 1, 1, 3<<3 | WireBytes, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f, 0x01} + Unmarshal(b, new(MyMessage)) +} + +func TestVarintOverflow(t *testing.T) { + // Overflowing a 64-bit length should not be allowed. 
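+	// (Field 3's length below is encoded as ten 0x80 continuation bytes plus
+	// a final 0x01 -- an 11-byte varint, one more than the maximum of ten
+	// bytes a 64-bit varint may occupy -- so decoding should fail cleanly.)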
+ b := []byte{1<<3 | WireVarint, 0x01, 3<<3 | WireBytes, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01} + if err := Unmarshal(b, new(MyMessage)); err == nil { + t.Fatalf("Overflowed uint64 length without error") + } +} + +func TestUnmarshalFuzz(t *testing.T) { + const N = 1000 + seed := time.Now().UnixNano() + t.Logf("RNG seed is %d", seed) + rng := rand.New(rand.NewSource(seed)) + buf := make([]byte, 20) + for i := 0; i < N; i++ { + for j := range buf { + buf[j] = byte(rng.Intn(256)) + } + fuzzUnmarshal(t, buf) + } +} + +func TestMergeMessages(t *testing.T) { + pb := &MessageList{Message: []*MessageList_Message{{Name: String("x"), Count: Int32(1)}}} + data, err := Marshal(pb) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + + pb1 := new(MessageList) + if err := Unmarshal(data, pb1); err != nil { + t.Fatalf("first Unmarshal: %v", err) + } + if err := Unmarshal(data, pb1); err != nil { + t.Fatalf("second Unmarshal: %v", err) + } + if len(pb1.Message) != 1 { + t.Errorf("two Unmarshals produced %d Messages, want 1", len(pb1.Message)) + } + + pb2 := new(MessageList) + if err := UnmarshalMerge(data, pb2); err != nil { + t.Fatalf("first UnmarshalMerge: %v", err) + } + if err := UnmarshalMerge(data, pb2); err != nil { + t.Fatalf("second UnmarshalMerge: %v", err) + } + if len(pb2.Message) != 2 { + t.Errorf("two UnmarshalMerges produced %d Messages, want 2", len(pb2.Message)) + } +} + +func TestExtensionMarshalOrder(t *testing.T) { + m := &MyMessage{Count: Int(123)} + if err := SetExtension(m, E_Ext_More, &Ext{Data: String("alpha")}); err != nil { + t.Fatalf("SetExtension: %v", err) + } + if err := SetExtension(m, E_Ext_Text, String("aleph")); err != nil { + t.Fatalf("SetExtension: %v", err) + } + if err := SetExtension(m, E_Ext_Number, Int32(1)); err != nil { + t.Fatalf("SetExtension: %v", err) + } + + // Serialize m several times, and check we get the same bytes each time. + var orig []byte + for i := 0; i < 100; i++ { + b, err := Marshal(m) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + if i == 0 { + orig = b + continue + } + if !bytes.Equal(b, orig) { + t.Errorf("Bytes differ on attempt #%d", i) + } + } +} + +// Many extensions, because small maps might not iterate differently on each iteration. +var exts = []*ExtensionDesc{ + E_X201, + E_X202, + E_X203, + E_X204, + E_X205, + E_X206, + E_X207, + E_X208, + E_X209, + E_X210, + E_X211, + E_X212, + E_X213, + E_X214, + E_X215, + E_X216, + E_X217, + E_X218, + E_X219, + E_X220, + E_X221, + E_X222, + E_X223, + E_X224, + E_X225, + E_X226, + E_X227, + E_X228, + E_X229, + E_X230, + E_X231, + E_X232, + E_X233, + E_X234, + E_X235, + E_X236, + E_X237, + E_X238, + E_X239, + E_X240, + E_X241, + E_X242, + E_X243, + E_X244, + E_X245, + E_X246, + E_X247, + E_X248, + E_X249, + E_X250, +} + +func TestMessageSetMarshalOrder(t *testing.T) { + m := &MyMessageSet{} + for _, x := range exts { + if err := SetExtension(m, x, &Empty{}); err != nil { + t.Fatalf("SetExtension: %v", err) + } + } + + buf, err := Marshal(m) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + + // Serialize m several times, and check we get the same bytes each time. 
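+	// (Go randomizes map iteration order between runs, so byte-for-byte
+	// stability here implies the encoder sorts extension field numbers
+	// instead of ranging over the extension map directly.)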
+ for i := 0; i < 10; i++ { + b1, err := Marshal(m) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + if !bytes.Equal(b1, buf) { + t.Errorf("Bytes differ on re-Marshal #%d", i) + } + + m2 := &MyMessageSet{} + if err := Unmarshal(buf, m2); err != nil { + t.Errorf("Unmarshal: %v", err) + } + b2, err := Marshal(m2) + if err != nil { + t.Errorf("re-Marshal: %v", err) + } + if !bytes.Equal(b2, buf) { + t.Errorf("Bytes differ on round-trip #%d", i) + } + } +} + +func TestUnmarshalMergesMessages(t *testing.T) { + // If a nested message occurs twice in the input, + // the fields should be merged when decoding. + a := &OtherMessage{ + Key: Int64(123), + Inner: &InnerMessage{ + Host: String("polhode"), + Port: Int32(1234), + }, + } + aData, err := Marshal(a) + if err != nil { + t.Fatalf("Marshal(a): %v", err) + } + b := &OtherMessage{ + Weight: Float32(1.2), + Inner: &InnerMessage{ + Host: String("herpolhode"), + Connected: Bool(true), + }, + } + bData, err := Marshal(b) + if err != nil { + t.Fatalf("Marshal(b): %v", err) + } + want := &OtherMessage{ + Key: Int64(123), + Weight: Float32(1.2), + Inner: &InnerMessage{ + Host: String("herpolhode"), + Port: Int32(1234), + Connected: Bool(true), + }, + } + got := new(OtherMessage) + if err := Unmarshal(append(aData, bData...), got); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if !Equal(got, want) { + t.Errorf("\n got %v\nwant %v", got, want) + } +} + +func TestEncodingSizes(t *testing.T) { + tests := []struct { + m Message + n int + }{ + {&Defaults{F_Int32: Int32(math.MaxInt32)}, 6}, + {&Defaults{F_Int32: Int32(math.MinInt32)}, 11}, + {&Defaults{F_Uint32: Uint32(uint32(math.MaxInt32) + 1)}, 6}, + {&Defaults{F_Uint32: Uint32(math.MaxUint32)}, 6}, + } + for _, test := range tests { + b, err := Marshal(test.m) + if err != nil { + t.Errorf("Marshal(%v): %v", test.m, err) + continue + } + if len(b) != test.n { + t.Errorf("Marshal(%v) yielded %d bytes, want %d bytes", test.m, len(b), test.n) + } + } +} + +func TestRequiredNotSetError(t *testing.T) { + pb := initGoTest(false) + pb.RequiredField.Label = nil + pb.F_Int32Required = nil + pb.F_Int64Required = nil + + expected := "0807" + // field 1, encoding 0, value 7 + "2206" + "120474797065" + // field 4, encoding 2 (GoTestField) + "5001" + // field 10, encoding 0, value 1 + "6d20000000" + // field 13, encoding 5, value 0x20 + "714000000000000000" + // field 14, encoding 1, value 0x40 + "78a019" + // field 15, encoding 0, value 0xca0 = 3232 + "8001c032" + // field 16, encoding 0, value 0x1940 = 6464 + "8d0100004a45" + // field 17, encoding 5, value 3232.0 + "9101000000000040b940" + // field 18, encoding 1, value 6464.0 + "9a0106" + "737472696e67" + // field 19, encoding 2, string "string" + "b304" + // field 70, encoding 3, start group + "ba0408" + "7265717569726564" + // field 71, encoding 2, string "required" + "b404" + // field 70, encoding 4, end group + "aa0605" + "6279746573" + // field 101, encoding 2, string "bytes" + "b0063f" + // field 102, encoding 0, 0x3f zigzag32 + "b8067f" // field 103, encoding 0, 0x7f zigzag64 + + o := old() + bytes, err := Marshal(pb) + if _, ok := err.(*RequiredNotSetError); !ok { + fmt.Printf("marshal-1 err = %v, want *RequiredNotSetError", err) + o.DebugPrint("", bytes) + t.Fatalf("expected = %s", expected) + } + if strings.Index(err.Error(), "RequiredField.Label") < 0 { + t.Errorf("marshal-1 wrong err msg: %v", err) + } + if !equal(bytes, expected, t) { + o.DebugPrint("neq 1", bytes) + t.Fatalf("expected = %s", expected) + } + + // Now test Unmarshal by 
recreating the original buffer. + pbd := new(GoTest) + err = Unmarshal(bytes, pbd) + if _, ok := err.(*RequiredNotSetError); !ok { + t.Fatalf("unmarshal err = %v, want *RequiredNotSetError", err) + o.DebugPrint("", bytes) + t.Fatalf("string = %s", expected) + } + if strings.Index(err.Error(), "RequiredField.{Unknown}") < 0 { + t.Errorf("unmarshal wrong err msg: %v", err) + } + bytes, err = Marshal(pbd) + if _, ok := err.(*RequiredNotSetError); !ok { + t.Errorf("marshal-2 err = %v, want *RequiredNotSetError", err) + o.DebugPrint("", bytes) + t.Fatalf("string = %s", expected) + } + if strings.Index(err.Error(), "RequiredField.Label") < 0 { + t.Errorf("marshal-2 wrong err msg: %v", err) + } + if !equal(bytes, expected, t) { + o.DebugPrint("neq 2", bytes) + t.Fatalf("string = %s", expected) + } +} + +func fuzzUnmarshal(t *testing.T, data []byte) { + defer func() { + if e := recover(); e != nil { + t.Errorf("These bytes caused a panic: %+v", data) + t.Logf("Stack:\n%s", debug.Stack()) + t.FailNow() + } + }() + + pb := new(MyMessage) + Unmarshal(data, pb) +} + +// Benchmarks + +func testMsg() *GoTest { + pb := initGoTest(true) + const N = 1000 // Internally the library starts much smaller. + pb.F_Int32Repeated = make([]int32, N) + pb.F_DoubleRepeated = make([]float64, N) + for i := 0; i < N; i++ { + pb.F_Int32Repeated[i] = int32(i) + pb.F_DoubleRepeated[i] = float64(i) + } + return pb +} + +func bytesMsg() *GoTest { + pb := initGoTest(true) + buf := make([]byte, 4000) + for i := range buf { + buf[i] = byte(i) + } + pb.F_BytesDefaulted = buf + return pb +} + +func benchmarkMarshal(b *testing.B, pb Message, marshal func(Message) ([]byte, error)) { + d, _ := marshal(pb) + b.SetBytes(int64(len(d))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + marshal(pb) + } +} + +func benchmarkBufferMarshal(b *testing.B, pb Message) { + p := NewBuffer(nil) + benchmarkMarshal(b, pb, func(pb0 Message) ([]byte, error) { + p.Reset() + err := p.Marshal(pb0) + return p.Bytes(), err + }) +} + +func benchmarkSize(b *testing.B, pb Message) { + benchmarkMarshal(b, pb, func(pb0 Message) ([]byte, error) { + Size(pb) + return nil, nil + }) +} + +func newOf(pb Message) Message { + in := reflect.ValueOf(pb) + if in.IsNil() { + return pb + } + return reflect.New(in.Type().Elem()).Interface().(Message) +} + +func benchmarkUnmarshal(b *testing.B, pb Message, unmarshal func([]byte, Message) error) { + d, _ := Marshal(pb) + b.SetBytes(int64(len(d))) + pbd := newOf(pb) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + unmarshal(d, pbd) + } +} + +func benchmarkBufferUnmarshal(b *testing.B, pb Message) { + p := NewBuffer(nil) + benchmarkUnmarshal(b, pb, func(d []byte, pb0 Message) error { + p.SetBuf(d) + return p.Unmarshal(pb0) + }) +} + +// Benchmark{Marshal,BufferMarshal,Size,Unmarshal,BufferUnmarshal}{,Bytes} + +func BenchmarkMarshal(b *testing.B) { + benchmarkMarshal(b, testMsg(), Marshal) +} + +func BenchmarkBufferMarshal(b *testing.B) { + benchmarkBufferMarshal(b, testMsg()) +} + +func BenchmarkSize(b *testing.B) { + benchmarkSize(b, testMsg()) +} + +func BenchmarkUnmarshal(b *testing.B) { + benchmarkUnmarshal(b, testMsg(), Unmarshal) +} + +func BenchmarkBufferUnmarshal(b *testing.B) { + benchmarkBufferUnmarshal(b, testMsg()) +} + +func BenchmarkMarshalBytes(b *testing.B) { + benchmarkMarshal(b, bytesMsg(), Marshal) +} + +func BenchmarkBufferMarshalBytes(b *testing.B) { + benchmarkBufferMarshal(b, bytesMsg()) +} + +func BenchmarkSizeBytes(b *testing.B) { + benchmarkSize(b, bytesMsg()) +} + +func BenchmarkUnmarshalBytes(b 
*testing.B) { + benchmarkUnmarshal(b, bytesMsg(), Unmarshal) +} + +func BenchmarkBufferUnmarshalBytes(b *testing.B) { + benchmarkBufferUnmarshal(b, bytesMsg()) +} + +func BenchmarkUnmarshalUnrecognizedFields(b *testing.B) { + b.StopTimer() + pb := initGoTestField() + skip := &GoSkipTest{ + SkipInt32: Int32(32), + SkipFixed32: Uint32(3232), + SkipFixed64: Uint64(6464), + SkipString: String("skipper"), + Skipgroup: &GoSkipTest_SkipGroup{ + GroupInt32: Int32(75), + GroupString: String("wxyz"), + }, + } + + pbd := new(GoTestField) + p := NewBuffer(nil) + p.Marshal(pb) + p.Marshal(skip) + p2 := NewBuffer(nil) + + b.StartTimer() + for i := 0; i < b.N; i++ { + p2.SetBuf(p.Bytes()) + p2.Unmarshal(pbd) + } +} diff --git a/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/clone.go b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/clone.go new file mode 100644 index 0000000000..8b94d9e186 --- /dev/null +++ b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/clone.go @@ -0,0 +1,169 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2011 The Go Authors. All rights reserved. +// http://code.google.com/p/goprotobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Protocol buffer deep copy. +// TODO: MessageSet and RawMessage. + +package proto + +import ( + "log" + "reflect" + "strings" +) + +// Clone returns a deep copy of a protocol buffer. +func Clone(pb Message) Message { + in := reflect.ValueOf(pb) + if in.IsNil() { + return pb + } + + out := reflect.New(in.Type().Elem()) + // out is empty so a merge is a deep copy. + mergeStruct(out.Elem(), in.Elem()) + return out.Interface().(Message) +} + +// Merge merges src into dst. +// Required and optional fields that are set in src will be set to that value in dst. +// Elements of repeated fields will be appended. +// Merge panics if src and dst are not the same type, or if dst is nil. 
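+//
+// A short usage sketch (pb.MyMessage here stands for any generated message
+// type, such as the one in this package's test data):
+//
+//	dst := &pb.MyMessage{Name: proto.String("Dave")}
+//	src := &pb.MyMessage{Count: proto.Int32(42)}
+//	proto.Merge(dst, src)
+//	// dst now has both Name and Count set; proto.Clone(dst) would
+//	// produce an independent deep copy.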
+func Merge(dst, src Message) { + in := reflect.ValueOf(src) + out := reflect.ValueOf(dst) + if out.IsNil() { + panic("proto: nil destination") + } + if in.Type() != out.Type() { + // Explicit test prior to mergeStruct so that mistyped nils will fail + panic("proto: type mismatch") + } + if in.IsNil() { + // Merging nil into non-nil is a quiet no-op + return + } + mergeStruct(out.Elem(), in.Elem()) +} + +func mergeStruct(out, in reflect.Value) { + for i := 0; i < in.NumField(); i++ { + f := in.Type().Field(i) + if strings.HasPrefix(f.Name, "XXX_") { + continue + } + mergeAny(out.Field(i), in.Field(i)) + } + + if emIn, ok := in.Addr().Interface().(extendableProto); ok { + emOut := out.Addr().Interface().(extendableProto) + mergeExtension(emOut.ExtensionMap(), emIn.ExtensionMap()) + } + + uf := in.FieldByName("XXX_unrecognized") + if !uf.IsValid() { + return + } + uin := uf.Bytes() + if len(uin) > 0 { + out.FieldByName("XXX_unrecognized").SetBytes(append([]byte(nil), uin...)) + } +} + +func mergeAny(out, in reflect.Value) { + if in.Type() == protoMessageType { + if !in.IsNil() { + if out.IsNil() { + out.Set(reflect.ValueOf(Clone(in.Interface().(Message)))) + } else { + Merge(out.Interface().(Message), in.Interface().(Message)) + } + } + return + } + switch in.Kind() { + case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64, + reflect.String, reflect.Uint32, reflect.Uint64: + out.Set(in) + case reflect.Ptr: + if in.IsNil() { + return + } + if out.IsNil() { + out.Set(reflect.New(in.Elem().Type())) + } + mergeAny(out.Elem(), in.Elem()) + case reflect.Slice: + if in.IsNil() { + return + } + n := in.Len() + if out.IsNil() { + out.Set(reflect.MakeSlice(in.Type(), 0, n)) + } + switch in.Type().Elem().Kind() { + case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64, + reflect.String, reflect.Uint32, reflect.Uint64: + out.Set(reflect.AppendSlice(out, in)) + case reflect.Uint8: + // []byte is a scalar bytes field. + out.Set(in) + default: + for i := 0; i < n; i++ { + x := reflect.Indirect(reflect.New(in.Type().Elem())) + mergeAny(x, in.Index(i)) + out.Set(reflect.Append(out, x)) + } + } + case reflect.Struct: + mergeStruct(out, in) + default: + // unknown type, so not a protocol buffer + log.Printf("proto: don't know how to copy %v", in) + } +} + +func mergeExtension(out, in map[int32]Extension) { + for extNum, eIn := range in { + eOut := Extension{desc: eIn.desc} + if eIn.value != nil { + v := reflect.New(reflect.TypeOf(eIn.value)).Elem() + mergeAny(v, reflect.ValueOf(eIn.value)) + eOut.value = v.Interface() + } + if eIn.enc != nil { + eOut.enc = make([]byte, len(eIn.enc)) + copy(eOut.enc, eIn.enc) + } + + out[extNum] = eOut + } +} diff --git a/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/clone_test.go b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/clone_test.go new file mode 100644 index 0000000000..39aaaf7bda --- /dev/null +++ b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/clone_test.go @@ -0,0 +1,186 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2011 The Go Authors. All rights reserved. +// http://code.google.com/p/goprotobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. 
+// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto_test + +import ( + "testing" + + "code.google.com/p/goprotobuf/proto" + + pb "./testdata" +) + +var cloneTestMessage = &pb.MyMessage{ + Count: proto.Int32(42), + Name: proto.String("Dave"), + Pet: []string{"bunny", "kitty", "horsey"}, + Inner: &pb.InnerMessage{ + Host: proto.String("niles"), + Port: proto.Int32(9099), + Connected: proto.Bool(true), + }, + Others: []*pb.OtherMessage{ + { + Value: []byte("some bytes"), + }, + }, + Somegroup: &pb.MyMessage_SomeGroup{ + GroupField: proto.Int32(6), + }, + RepBytes: [][]byte{[]byte("sham"), []byte("wow")}, +} + +func init() { + ext := &pb.Ext{ + Data: proto.String("extension"), + } + if err := proto.SetExtension(cloneTestMessage, pb.E_Ext_More, ext); err != nil { + panic("SetExtension: " + err.Error()) + } +} + +func TestClone(t *testing.T) { + m := proto.Clone(cloneTestMessage).(*pb.MyMessage) + if !proto.Equal(m, cloneTestMessage) { + t.Errorf("Clone(%v) = %v", cloneTestMessage, m) + } + + // Verify it was a deep copy. 
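+	// (Mutating the clone through a nested pointer and comparing again
+	// proves the copy shares no memory with the original.)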
+ *m.Inner.Port++ + if proto.Equal(m, cloneTestMessage) { + t.Error("Mutating clone changed the original") + } +} + +func TestCloneNil(t *testing.T) { + var m *pb.MyMessage + if c := proto.Clone(m); !proto.Equal(m, c) { + t.Errorf("Clone(%v) = %v", m, c) + } +} + +var mergeTests = []struct { + src, dst, want proto.Message +}{ + { + src: &pb.MyMessage{ + Count: proto.Int32(42), + }, + dst: &pb.MyMessage{ + Name: proto.String("Dave"), + }, + want: &pb.MyMessage{ + Count: proto.Int32(42), + Name: proto.String("Dave"), + }, + }, + { + src: &pb.MyMessage{ + Inner: &pb.InnerMessage{ + Host: proto.String("hey"), + Connected: proto.Bool(true), + }, + Pet: []string{"horsey"}, + Others: []*pb.OtherMessage{ + { + Value: []byte("some bytes"), + }, + }, + }, + dst: &pb.MyMessage{ + Inner: &pb.InnerMessage{ + Host: proto.String("niles"), + Port: proto.Int32(9099), + }, + Pet: []string{"bunny", "kitty"}, + Others: []*pb.OtherMessage{ + { + Key: proto.Int64(31415926535), + }, + { + // Explicitly test a src=nil field + Inner: nil, + }, + }, + }, + want: &pb.MyMessage{ + Inner: &pb.InnerMessage{ + Host: proto.String("hey"), + Connected: proto.Bool(true), + Port: proto.Int32(9099), + }, + Pet: []string{"bunny", "kitty", "horsey"}, + Others: []*pb.OtherMessage{ + { + Key: proto.Int64(31415926535), + }, + {}, + { + Value: []byte("some bytes"), + }, + }, + }, + }, + { + src: &pb.MyMessage{ + RepBytes: [][]byte{[]byte("wow")}, + }, + dst: &pb.MyMessage{ + Somegroup: &pb.MyMessage_SomeGroup{ + GroupField: proto.Int32(6), + }, + RepBytes: [][]byte{[]byte("sham")}, + }, + want: &pb.MyMessage{ + Somegroup: &pb.MyMessage_SomeGroup{ + GroupField: proto.Int32(6), + }, + RepBytes: [][]byte{[]byte("sham"), []byte("wow")}, + }, + }, + // Check that a scalar bytes field replaces rather than appends. + { + src: &pb.OtherMessage{Value: []byte("foo")}, + dst: &pb.OtherMessage{Value: []byte("bar")}, + want: &pb.OtherMessage{Value: []byte("foo")}, + }, +} + +func TestMerge(t *testing.T) { + for _, m := range mergeTests { + got := proto.Clone(m.dst) + proto.Merge(got, m.src) + if !proto.Equal(got, m.want) { + t.Errorf("Merge(%v, %v)\n got %v\nwant %v\n", m.dst, m.src, got, m.want) + } + } +} diff --git a/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/decode.go b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/decode.go new file mode 100644 index 0000000000..49a487a375 --- /dev/null +++ b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/decode.go @@ -0,0 +1,721 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2010 The Go Authors. All rights reserved. +// http://code.google.com/p/goprotobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto + +/* + * Routines for decoding protocol buffer data to construct in-memory representations. + */ + +import ( + "errors" + "fmt" + "io" + "os" + "reflect" +) + +// errOverflow is returned when an integer is too large to be represented. +var errOverflow = errors.New("proto: integer overflow") + +// The fundamental decoders that interpret bytes on the wire. +// Those that take integer types all return uint64 and are +// therefore of type valueDecoder. + +// DecodeVarint reads a varint-encoded integer from the slice. +// It returns the integer and the number of bytes consumed, or +// zero if there is not enough. +// This is the format for the +// int32, int64, uint32, uint64, bool, and enum +// protocol buffer types. +func DecodeVarint(buf []byte) (x uint64, n int) { + // x, n already 0 + for shift := uint(0); shift < 64; shift += 7 { + if n >= len(buf) { + return 0, 0 + } + b := uint64(buf[n]) + n++ + x |= (b & 0x7F) << shift + if (b & 0x80) == 0 { + return x, n + } + } + + // The number is too large to represent in a 64-bit value. + return 0, 0 +} + +// DecodeVarint reads a varint-encoded integer from the Buffer. +// This is the format for the +// int32, int64, uint32, uint64, bool, and enum +// protocol buffer types. +func (p *Buffer) DecodeVarint() (x uint64, err error) { + // x, err already 0 + + i := p.index + l := len(p.buf) + + for shift := uint(0); shift < 64; shift += 7 { + if i >= l { + err = io.ErrUnexpectedEOF + return + } + b := p.buf[i] + i++ + x |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + p.index = i + return + } + } + + // The number is too large to represent in a 64-bit value. + err = errOverflow + return +} + +// DecodeFixed64 reads a 64-bit integer from the Buffer. +// This is the format for the +// fixed64, sfixed64, and double protocol buffer types. +func (p *Buffer) DecodeFixed64() (x uint64, err error) { + // x, err already 0 + i := p.index + 8 + if i < 0 || i > len(p.buf) { + err = io.ErrUnexpectedEOF + return + } + p.index = i + + x = uint64(p.buf[i-8]) + x |= uint64(p.buf[i-7]) << 8 + x |= uint64(p.buf[i-6]) << 16 + x |= uint64(p.buf[i-5]) << 24 + x |= uint64(p.buf[i-4]) << 32 + x |= uint64(p.buf[i-3]) << 40 + x |= uint64(p.buf[i-2]) << 48 + x |= uint64(p.buf[i-1]) << 56 + return +} + +// DecodeFixed32 reads a 32-bit integer from the Buffer. +// This is the format for the +// fixed32, sfixed32, and float protocol buffer types. 
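+//
+// For example (an illustrative value, not taken from this file):
+//
+//	buf := []byte{0x40, 0xe2, 0x01, 0x00} // 123456 in little-endian order
+//	x := uint64(buf[0]) | uint64(buf[1])<<8 | uint64(buf[2])<<16 | uint64(buf[3])<<24
+//	// x == 123456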
+func (p *Buffer) DecodeFixed32() (x uint64, err error) { + // x, err already 0 + i := p.index + 4 + if i < 0 || i > len(p.buf) { + err = io.ErrUnexpectedEOF + return + } + p.index = i + + x = uint64(p.buf[i-4]) + x |= uint64(p.buf[i-3]) << 8 + x |= uint64(p.buf[i-2]) << 16 + x |= uint64(p.buf[i-1]) << 24 + return +} + +// DecodeZigzag64 reads a zigzag-encoded 64-bit integer +// from the Buffer. +// This is the format used for the sint64 protocol buffer type. +func (p *Buffer) DecodeZigzag64() (x uint64, err error) { + x, err = p.DecodeVarint() + if err != nil { + return + } + x = (x >> 1) ^ uint64((int64(x&1)<<63)>>63) + return +} + +// DecodeZigzag32 reads a zigzag-encoded 32-bit integer +// from the Buffer. +// This is the format used for the sint32 protocol buffer type. +func (p *Buffer) DecodeZigzag32() (x uint64, err error) { + x, err = p.DecodeVarint() + if err != nil { + return + } + x = uint64((uint32(x) >> 1) ^ uint32((int32(x&1)<<31)>>31)) + return +} + +// These are not ValueDecoders: they produce an array of bytes or a string. +// bytes, embedded messages + +// DecodeRawBytes reads a count-delimited byte buffer from the Buffer. +// This is the format used for the bytes protocol buffer +// type and for embedded messages. +func (p *Buffer) DecodeRawBytes(alloc bool) (buf []byte, err error) { + n, err := p.DecodeVarint() + if err != nil { + return + } + + nb := int(n) + if nb < 0 { + return nil, fmt.Errorf("proto: bad byte length %d", nb) + } + end := p.index + nb + if end < p.index || end > len(p.buf) { + return nil, io.ErrUnexpectedEOF + } + + if !alloc { + // todo: check if can get more uses of alloc=false + buf = p.buf[p.index:end] + p.index += nb + return + } + + buf = make([]byte, nb) + copy(buf, p.buf[p.index:]) + p.index += nb + return +} + +// DecodeStringBytes reads an encoded string from the Buffer. +// This is the format used for the proto2 string type. +func (p *Buffer) DecodeStringBytes() (s string, err error) { + buf, err := p.DecodeRawBytes(false) + if err != nil { + return + } + return string(buf), nil +} + +// Skip the next item in the buffer. Its wire type is decoded and presented as an argument. +// If the protocol buffer has extensions, and the field matches, add it as an extension. +// Otherwise, if the XXX_unrecognized field exists, append the skipped data there. +func (o *Buffer) skipAndSave(t reflect.Type, tag, wire int, base structPointer, unrecField field) error { + oi := o.index + + err := o.skip(t, tag, wire) + if err != nil { + return err + } + + if !unrecField.IsValid() { + return nil + } + + ptr := structPointer_Bytes(base, unrecField) + + // Add the skipped field to struct field + obuf := o.buf + + o.buf = *ptr + o.EncodeVarint(uint64(tag<<3 | wire)) + *ptr = append(o.buf, obuf[oi:o.index]...) + + o.buf = obuf + + return nil +} + +// Skip the next item in the buffer. Its wire type is decoded and presented as an argument. 
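+// Skipping is wire-type driven: varints, fixed-width values, and
+// length-delimited bytes are consumed directly, while a start-group tag is
+// skipped by recursing until the matching end-group tag is found.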
+func (o *Buffer) skip(t reflect.Type, tag, wire int) error { + + var u uint64 + var err error + + switch wire { + case WireVarint: + _, err = o.DecodeVarint() + case WireFixed64: + _, err = o.DecodeFixed64() + case WireBytes: + _, err = o.DecodeRawBytes(false) + case WireFixed32: + _, err = o.DecodeFixed32() + case WireStartGroup: + for { + u, err = o.DecodeVarint() + if err != nil { + break + } + fwire := int(u & 0x7) + if fwire == WireEndGroup { + break + } + ftag := int(u >> 3) + err = o.skip(t, ftag, fwire) + if err != nil { + break + } + } + default: + err = fmt.Errorf("proto: can't skip unknown wire type %d for %s", wire, t) + } + return err +} + +// Unmarshaler is the interface representing objects that can +// unmarshal themselves. The method should reset the receiver before +// decoding starts. The argument points to data that may be +// overwritten, so implementations should not keep references to the +// buffer. +type Unmarshaler interface { + Unmarshal([]byte) error +} + +// Unmarshal parses the protocol buffer representation in buf and places the +// decoded result in pb. If the struct underlying pb does not match +// the data in buf, the results can be unpredictable. +// +// Unmarshal resets pb before starting to unmarshal, so any +// existing data in pb is always removed. Use UnmarshalMerge +// to preserve and append to existing data. +func Unmarshal(buf []byte, pb Message) error { + pb.Reset() + return UnmarshalMerge(buf, pb) +} + +// UnmarshalMerge parses the protocol buffer representation in buf and +// writes the decoded result to pb. If the struct underlying pb does not match +// the data in buf, the results can be unpredictable. +// +// UnmarshalMerge merges into existing data in pb. +// Most code should use Unmarshal instead. +func UnmarshalMerge(buf []byte, pb Message) error { + // If the object can unmarshal itself, let it. + if u, ok := pb.(Unmarshaler); ok { + return u.Unmarshal(buf) + } + return NewBuffer(buf).Unmarshal(pb) +} + +// Unmarshal parses the protocol buffer representation in the +// Buffer and places the decoded result in pb. If the struct +// underlying pb does not match the data in the buffer, the results can be +// unpredictable. +func (p *Buffer) Unmarshal(pb Message) error { + // If the object can unmarshal itself, let it. + if u, ok := pb.(Unmarshaler); ok { + err := u.Unmarshal(p.buf[p.index:]) + p.index = len(p.buf) + return err + } + + typ, base, err := getbase(pb) + if err != nil { + return err + } + + err = p.unmarshalType(typ.Elem(), GetProperties(typ.Elem()), false, base) + + if collectStats { + stats.Decode++ + } + + return err +} + +// unmarshalType does the work of unmarshaling a structure. +func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group bool, base structPointer) error { + var state errorState + required, reqFields := prop.reqCount, uint64(0) + + var err error + for err == nil && o.index < len(o.buf) { + oi := o.index + var u uint64 + u, err = o.DecodeVarint() + if err != nil { + break + } + wire := int(u & 0x7) + if wire == WireEndGroup { + if is_group { + return nil // input is satisfied + } + return fmt.Errorf("proto: %s: wiretype end group for non-group", st) + } + tag := int(u >> 3) + if tag <= 0 { + return fmt.Errorf("proto: %s: illegal tag %d", st, tag) + } + fieldnum, ok := prop.decoderTags.get(tag) + if !ok { + // Maybe it's an extension? 
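+			// (If the message is extendable and the tag lies in a declared
+			// extension range, the raw bytes are stashed in the extension map
+			// and decoded lazily later; otherwise the field is skipped into
+			// XXX_unrecognized below.)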
+ if prop.extendable { + if e := structPointer_Interface(base, st).(extendableProto); isExtensionField(e, int32(tag)) { + if err = o.skip(st, tag, wire); err == nil { + ext := e.ExtensionMap()[int32(tag)] // may be missing + ext.enc = append(ext.enc, o.buf[oi:o.index]...) + e.ExtensionMap()[int32(tag)] = ext + } + continue + } + } + err = o.skipAndSave(st, tag, wire, base, prop.unrecField) + continue + } + p := prop.Prop[fieldnum] + + if p.dec == nil { + fmt.Fprintf(os.Stderr, "proto: no protobuf decoder for %s.%s\n", st, st.Field(fieldnum).Name) + continue + } + dec := p.dec + if wire != WireStartGroup && wire != p.WireType { + if wire == WireBytes && p.packedDec != nil { + // a packable field + dec = p.packedDec + } else { + err = fmt.Errorf("proto: bad wiretype for field %s.%s: got wiretype %d, want %d", st, st.Field(fieldnum).Name, wire, p.WireType) + continue + } + } + decErr := dec(o, p, base) + if decErr != nil && !state.shouldContinue(decErr, p) { + err = decErr + } + if err == nil && p.Required { + // Successfully decoded a required field. + if tag <= 64 { + // use bitmap for fields 1-64 to catch field reuse. + var mask uint64 = 1 << uint64(tag-1) + if reqFields&mask == 0 { + // new required field + reqFields |= mask + required-- + } + } else { + // This is imprecise. It can be fooled by a required field + // with a tag > 64 that is encoded twice; that's very rare. + // A fully correct implementation would require allocating + // a data structure, which we would like to avoid. + required-- + } + } + } + if err == nil { + if is_group { + return io.ErrUnexpectedEOF + } + if state.err != nil { + return state.err + } + if required > 0 { + // Not enough information to determine the exact field. If we use extra + // CPU, we could determine the field only if the missing required field + // has a tag <= 64 and we check reqFields. + return &RequiredNotSetError{"{Unknown}"} + } + } + return err +} + +// Individual type decoders +// For each, +// u is the decoded value, +// v is a pointer to the field (pointer) in the struct + +// Sizes of the pools to allocate inside the Buffer. +// The goal is modest amortization and allocation +// on at least 16-byte boundaries. +const ( + boolPoolSize = 16 + uint32PoolSize = 8 + uint64PoolSize = 4 +) + +// Decode a bool. +func (o *Buffer) dec_bool(p *Properties, base structPointer) error { + u, err := p.valDec(o) + if err != nil { + return err + } + if len(o.bools) == 0 { + o.bools = make([]bool, boolPoolSize) + } + o.bools[0] = u != 0 + *structPointer_Bool(base, p.field) = &o.bools[0] + o.bools = o.bools[1:] + return nil +} + +// Decode an int32. +func (o *Buffer) dec_int32(p *Properties, base structPointer) error { + u, err := p.valDec(o) + if err != nil { + return err + } + word32_Set(structPointer_Word32(base, p.field), o, uint32(u)) + return nil +} + +// Decode an int64. +func (o *Buffer) dec_int64(p *Properties, base structPointer) error { + u, err := p.valDec(o) + if err != nil { + return err + } + word64_Set(structPointer_Word64(base, p.field), o, u) + return nil +} + +// Decode a string. +func (o *Buffer) dec_string(p *Properties, base structPointer) error { + s, err := o.DecodeStringBytes() + if err != nil { + return err + } + sp := new(string) + *sp = s + *structPointer_String(base, p.field) = sp + return nil +} + +// Decode a slice of bytes ([]byte). 
+func (o *Buffer) dec_slice_byte(p *Properties, base structPointer) error { + b, err := o.DecodeRawBytes(true) + if err != nil { + return err + } + *structPointer_Bytes(base, p.field) = b + return nil +} + +// Decode a slice of bools ([]bool). +func (o *Buffer) dec_slice_bool(p *Properties, base structPointer) error { + u, err := p.valDec(o) + if err != nil { + return err + } + v := structPointer_BoolSlice(base, p.field) + *v = append(*v, u != 0) + return nil +} + +// Decode a slice of bools ([]bool) in packed format. +func (o *Buffer) dec_slice_packed_bool(p *Properties, base structPointer) error { + v := structPointer_BoolSlice(base, p.field) + + nn, err := o.DecodeVarint() + if err != nil { + return err + } + nb := int(nn) // number of bytes of encoded bools + + y := *v + for i := 0; i < nb; i++ { + u, err := p.valDec(o) + if err != nil { + return err + } + y = append(y, u != 0) + } + + *v = y + return nil +} + +// Decode a slice of int32s ([]int32). +func (o *Buffer) dec_slice_int32(p *Properties, base structPointer) error { + u, err := p.valDec(o) + if err != nil { + return err + } + structPointer_Word32Slice(base, p.field).Append(uint32(u)) + return nil +} + +// Decode a slice of int32s ([]int32) in packed format. +func (o *Buffer) dec_slice_packed_int32(p *Properties, base structPointer) error { + v := structPointer_Word32Slice(base, p.field) + + nn, err := o.DecodeVarint() + if err != nil { + return err + } + nb := int(nn) // number of bytes of encoded int32s + + fin := o.index + nb + if fin < o.index { + return errOverflow + } + for o.index < fin { + u, err := p.valDec(o) + if err != nil { + return err + } + v.Append(uint32(u)) + } + return nil +} + +// Decode a slice of int64s ([]int64). +func (o *Buffer) dec_slice_int64(p *Properties, base structPointer) error { + u, err := p.valDec(o) + if err != nil { + return err + } + + structPointer_Word64Slice(base, p.field).Append(u) + return nil +} + +// Decode a slice of int64s ([]int64) in packed format. +func (o *Buffer) dec_slice_packed_int64(p *Properties, base structPointer) error { + v := structPointer_Word64Slice(base, p.field) + + nn, err := o.DecodeVarint() + if err != nil { + return err + } + nb := int(nn) // number of bytes of encoded int64s + + fin := o.index + nb + if fin < o.index { + return errOverflow + } + for o.index < fin { + u, err := p.valDec(o) + if err != nil { + return err + } + v.Append(u) + } + return nil +} + +// Decode a slice of strings ([]string). +func (o *Buffer) dec_slice_string(p *Properties, base structPointer) error { + s, err := o.DecodeStringBytes() + if err != nil { + return err + } + v := structPointer_StringSlice(base, p.field) + *v = append(*v, s) + return nil +} + +// Decode a slice of slice of bytes ([][]byte). +func (o *Buffer) dec_slice_slice_byte(p *Properties, base structPointer) error { + b, err := o.DecodeRawBytes(true) + if err != nil { + return err + } + v := structPointer_BytesSlice(base, p.field) + *v = append(*v, b) + return nil +} + +// Decode a group. +func (o *Buffer) dec_struct_group(p *Properties, base structPointer) error { + bas := structPointer_GetStructPointer(base, p.field) + if structPointer_IsNil(bas) { + // allocate new nested message + bas = toStructPointer(reflect.New(p.stype)) + structPointer_SetStructPointer(base, p.field, bas) + } + return o.unmarshalType(p.stype, p.sprop, true, bas) +} + +// Decode an embedded message. 
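+// The decoder below temporarily points the Buffer at the raw sub-slice and
+// restores the original buffer and index afterwards, so the nested message
+// is decoded recursively without copying its bytes.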
+func (o *Buffer) dec_struct_message(p *Properties, base structPointer) (err error) { + raw, e := o.DecodeRawBytes(false) + if e != nil { + return e + } + + bas := structPointer_GetStructPointer(base, p.field) + if structPointer_IsNil(bas) { + // allocate new nested message + bas = toStructPointer(reflect.New(p.stype)) + structPointer_SetStructPointer(base, p.field, bas) + } + + // If the object can unmarshal itself, let it. + if p.isUnmarshaler { + iv := structPointer_Interface(bas, p.stype) + return iv.(Unmarshaler).Unmarshal(raw) + } + + obuf := o.buf + oi := o.index + o.buf = raw + o.index = 0 + + err = o.unmarshalType(p.stype, p.sprop, false, bas) + o.buf = obuf + o.index = oi + + return err +} + +// Decode a slice of embedded messages. +func (o *Buffer) dec_slice_struct_message(p *Properties, base structPointer) error { + return o.dec_slice_struct(p, false, base) +} + +// Decode a slice of embedded groups. +func (o *Buffer) dec_slice_struct_group(p *Properties, base structPointer) error { + return o.dec_slice_struct(p, true, base) +} + +// Decode a slice of structs ([]*struct). +func (o *Buffer) dec_slice_struct(p *Properties, is_group bool, base structPointer) error { + v := reflect.New(p.stype) + bas := toStructPointer(v) + structPointer_StructPointerSlice(base, p.field).Append(bas) + + if is_group { + err := o.unmarshalType(p.stype, p.sprop, is_group, bas) + return err + } + + raw, err := o.DecodeRawBytes(false) + if err != nil { + return err + } + + // If the object can unmarshal itself, let it. + if p.isUnmarshaler { + iv := v.Interface() + return iv.(Unmarshaler).Unmarshal(raw) + } + + obuf := o.buf + oi := o.index + o.buf = raw + o.index = 0 + + err = o.unmarshalType(p.stype, p.sprop, is_group, bas) + + o.buf = obuf + o.index = oi + + return err +} diff --git a/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/encode.go b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/encode.go new file mode 100644 index 0000000000..35e5ad56ef --- /dev/null +++ b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/encode.go @@ -0,0 +1,1054 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2010 The Go Authors. All rights reserved. +// http://code.google.com/p/goprotobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto + +/* + * Routines for encoding data into the wire format for protocol buffers. + */ + +import ( + "errors" + "fmt" + "reflect" + "sort" +) + +// RequiredNotSetError is the error returned if Marshal is called with +// a protocol buffer struct whose required fields have not +// all been initialized. It is also the error returned if Unmarshal is +// called with an encoded protocol buffer that does not include all the +// required fields. +// +// When printed, RequiredNotSetError reports the first unset required field in a +// message. If the field cannot be precisely determined, it is reported as +// "{Unknown}". +type RequiredNotSetError struct { + field string +} + +func (e *RequiredNotSetError) Error() string { + return fmt.Sprintf("proto: required field %q not set", e.field) +} + +var ( + // ErrRepeatedHasNil is the error returned if Marshal is called with + // a struct with a repeated field containing a nil element. + ErrRepeatedHasNil = errors.New("proto: repeated field has nil element") + + // ErrNil is the error returned if Marshal is called with nil. + ErrNil = errors.New("proto: Marshal called with nil") +) + +// The fundamental encoders that put bytes on the wire. +// Those that take integer types all accept uint64 and are +// therefore of type valueEncoder. + +const maxVarintBytes = 10 // maximum length of a varint + +// EncodeVarint returns the varint encoding of x. +// This is the format for the +// int32, int64, uint32, uint64, bool, and enum +// protocol buffer types. +// Not used by the package itself, but helpful to clients +// wishing to use the same encoding. +func EncodeVarint(x uint64) []byte { + var buf [maxVarintBytes]byte + var n int + for n = 0; x > 127; n++ { + buf[n] = 0x80 | uint8(x&0x7F) + x >>= 7 + } + buf[n] = uint8(x) + n++ + return buf[0:n] +} + +// EncodeVarint writes a varint-encoded integer to the Buffer. +// This is the format for the +// int32, int64, uint32, uint64, bool, and enum +// protocol buffer types. +func (p *Buffer) EncodeVarint(x uint64) error { + for x >= 1<<7 { + p.buf = append(p.buf, uint8(x&0x7f|0x80)) + x >>= 7 + } + p.buf = append(p.buf, uint8(x)) + return nil +} + +func sizeVarint(x uint64) (n int) { + for { + n++ + x >>= 7 + if x == 0 { + break + } + } + return n +} + +// EncodeFixed64 writes a 64-bit integer to the Buffer. +// This is the format for the +// fixed64, sfixed64, and double protocol buffer types. +func (p *Buffer) EncodeFixed64(x uint64) error { + p.buf = append(p.buf, + uint8(x), + uint8(x>>8), + uint8(x>>16), + uint8(x>>24), + uint8(x>>32), + uint8(x>>40), + uint8(x>>48), + uint8(x>>56)) + return nil +} + +func sizeFixed64(x uint64) int { + return 8 +} + +// EncodeFixed32 writes a 32-bit integer to the Buffer. +// This is the format for the +// fixed32, sfixed32, and float protocol buffer types. 
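+// Only the low 32 bits of x are written, in little-endian byte order.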
+func (p *Buffer) EncodeFixed32(x uint64) error { + p.buf = append(p.buf, + uint8(x), + uint8(x>>8), + uint8(x>>16), + uint8(x>>24)) + return nil +} + +func sizeFixed32(x uint64) int { + return 4 +} + +// EncodeZigzag64 writes a zigzag-encoded 64-bit integer +// to the Buffer. +// This is the format used for the sint64 protocol buffer type. +func (p *Buffer) EncodeZigzag64(x uint64) error { + // use signed number to get arithmetic right shift. + return p.EncodeVarint(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} + +func sizeZigzag64(x uint64) int { + return sizeVarint(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} + +// EncodeZigzag32 writes a zigzag-encoded 32-bit integer +// to the Buffer. +// This is the format used for the sint32 protocol buffer type. +func (p *Buffer) EncodeZigzag32(x uint64) error { + // use signed number to get arithmetic right shift. + return p.EncodeVarint(uint64((uint32(x) << 1) ^ uint32((int32(x) >> 31)))) +} + +func sizeZigzag32(x uint64) int { + return sizeVarint(uint64((uint32(x) << 1) ^ uint32((int32(x) >> 31)))) +} + +// EncodeRawBytes writes a count-delimited byte buffer to the Buffer. +// This is the format used for the bytes protocol buffer +// type and for embedded messages. +func (p *Buffer) EncodeRawBytes(b []byte) error { + p.EncodeVarint(uint64(len(b))) + p.buf = append(p.buf, b...) + return nil +} + +func sizeRawBytes(b []byte) int { + return sizeVarint(uint64(len(b))) + + len(b) +} + +// EncodeStringBytes writes an encoded string to the Buffer. +// This is the format used for the proto2 string type. +func (p *Buffer) EncodeStringBytes(s string) error { + p.EncodeVarint(uint64(len(s))) + p.buf = append(p.buf, s...) + return nil +} + +func sizeStringBytes(s string) int { + return sizeVarint(uint64(len(s))) + + len(s) +} + +// Marshaler is the interface representing objects that can marshal themselves. +type Marshaler interface { + Marshal() ([]byte, error) +} + +// Marshal takes the protocol buffer +// and encodes it into the wire format, returning the data. +func Marshal(pb Message) ([]byte, error) { + // Can the object marshal itself? + if m, ok := pb.(Marshaler); ok { + return m.Marshal() + } + p := NewBuffer(nil) + err := p.Marshal(pb) + var state errorState + if err != nil && !state.shouldContinue(err, nil) { + return nil, err + } + if p.buf == nil && err == nil { + // Return a non-nil slice on success. + return []byte{}, nil + } + return p.buf, err +} + +// Marshal takes the protocol buffer +// and encodes it into the wire format, writing the result to the +// Buffer. +func (p *Buffer) Marshal(pb Message) error { + // Can the object marshal itself? + if m, ok := pb.(Marshaler); ok { + data, err := m.Marshal() + if err != nil { + return err + } + p.buf = append(p.buf, data...) + return nil + } + + t, base, err := getbase(pb) + if structPointer_IsNil(base) { + return ErrNil + } + if err == nil { + err = p.enc_struct(t.Elem(), GetProperties(t.Elem()), base) + } + + if collectStats { + stats.Encode++ + } + + return err +} + +// Size returns the encoded size of a protocol buffer. +func Size(pb Message) (n int) { + // Can the object marshal itself? If so, Size is slow. + // TODO: add Size to Marshaler, or add a Sizer interface. + if m, ok := pb.(Marshaler); ok { + b, _ := m.Marshal() + return len(b) + } + + t, base, err := getbase(pb) + if structPointer_IsNil(base) { + return 0 + } + if err == nil { + n = size_struct(t.Elem(), GetProperties(t.Elem()), base) + } + + if collectStats { + stats.Size++ + } + + return +} + +// Individual type encoders. 
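+//
+// Each enc_* encoder appends the field's precomputed tag bytes (p.tagcode)
+// followed by the encoded value; the matching size_* function computes the
+// same byte count without writing anything.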
+ +// Encode a bool. +func (o *Buffer) enc_bool(p *Properties, base structPointer) error { + v := *structPointer_Bool(base, p.field) + if v == nil { + return ErrNil + } + x := 0 + if *v { + x = 1 + } + o.buf = append(o.buf, p.tagcode...) + p.valEnc(o, uint64(x)) + return nil +} + +func size_bool(p *Properties, base structPointer) int { + v := *structPointer_Bool(base, p.field) + if v == nil { + return 0 + } + return len(p.tagcode) + 1 // each bool takes exactly one byte +} + +// Encode an int32. +func (o *Buffer) enc_int32(p *Properties, base structPointer) error { + v := structPointer_Word32(base, p.field) + if word32_IsNil(v) { + return ErrNil + } + x := int32(word32_Get(v)) // permit sign extension to use full 64-bit range + o.buf = append(o.buf, p.tagcode...) + p.valEnc(o, uint64(x)) + return nil +} + +func size_int32(p *Properties, base structPointer) (n int) { + v := structPointer_Word32(base, p.field) + if word32_IsNil(v) { + return 0 + } + x := int32(word32_Get(v)) // permit sign extension to use full 64-bit range + n += len(p.tagcode) + n += p.valSize(uint64(x)) + return +} + +// Encode a uint32. +// Exactly the same as int32, except for no sign extension. +func (o *Buffer) enc_uint32(p *Properties, base structPointer) error { + v := structPointer_Word32(base, p.field) + if word32_IsNil(v) { + return ErrNil + } + x := word32_Get(v) + o.buf = append(o.buf, p.tagcode...) + p.valEnc(o, uint64(x)) + return nil +} + +func size_uint32(p *Properties, base structPointer) (n int) { + v := structPointer_Word32(base, p.field) + if word32_IsNil(v) { + return 0 + } + x := word32_Get(v) + n += len(p.tagcode) + n += p.valSize(uint64(x)) + return +} + +// Encode an int64. +func (o *Buffer) enc_int64(p *Properties, base structPointer) error { + v := structPointer_Word64(base, p.field) + if word64_IsNil(v) { + return ErrNil + } + x := word64_Get(v) + o.buf = append(o.buf, p.tagcode...) + p.valEnc(o, x) + return nil +} + +func size_int64(p *Properties, base structPointer) (n int) { + v := structPointer_Word64(base, p.field) + if word64_IsNil(v) { + return 0 + } + x := word64_Get(v) + n += len(p.tagcode) + n += p.valSize(x) + return +} + +// Encode a string. +func (o *Buffer) enc_string(p *Properties, base structPointer) error { + v := *structPointer_String(base, p.field) + if v == nil { + return ErrNil + } + x := *v + o.buf = append(o.buf, p.tagcode...) + o.EncodeStringBytes(x) + return nil +} + +func size_string(p *Properties, base structPointer) (n int) { + v := *structPointer_String(base, p.field) + if v == nil { + return 0 + } + x := *v + n += len(p.tagcode) + n += sizeStringBytes(x) + return +} + +// All protocol buffer fields are nillable, but be careful. +func isNil(v reflect.Value) bool { + switch v.Kind() { + case reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: + return v.IsNil() + } + return false +} + +// Encode a message struct. +func (o *Buffer) enc_struct_message(p *Properties, base structPointer) error { + var state errorState + structp := structPointer_GetStructPointer(base, p.field) + if structPointer_IsNil(structp) { + return ErrNil + } + + // Can the object marshal itself? + if p.isMarshaler { + m := structPointer_Interface(structp, p.stype).(Marshaler) + data, err := m.Marshal() + if err != nil && !state.shouldContinue(err, nil) { + return err + } + o.buf = append(o.buf, p.tagcode...) + o.EncodeRawBytes(data) + return nil + } + + o.buf = append(o.buf, p.tagcode...) 
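+	// The nested message is written length-delimited: enc_len_struct emits a
+	// varint length followed by the encoded struct body.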
+ return o.enc_len_struct(p.stype, p.sprop, structp, &state) +} + +func size_struct_message(p *Properties, base structPointer) int { + structp := structPointer_GetStructPointer(base, p.field) + if structPointer_IsNil(structp) { + return 0 + } + + // Can the object marshal itself? + if p.isMarshaler { + m := structPointer_Interface(structp, p.stype).(Marshaler) + data, _ := m.Marshal() + n0 := len(p.tagcode) + n1 := sizeRawBytes(data) + return n0 + n1 + } + + n0 := len(p.tagcode) + n1 := size_struct(p.stype, p.sprop, structp) + n2 := sizeVarint(uint64(n1)) // size of encoded length + return n0 + n1 + n2 +} + +// Encode a group struct. +func (o *Buffer) enc_struct_group(p *Properties, base structPointer) error { + var state errorState + b := structPointer_GetStructPointer(base, p.field) + if structPointer_IsNil(b) { + return ErrNil + } + + o.EncodeVarint(uint64((p.Tag << 3) | WireStartGroup)) + err := o.enc_struct(p.stype, p.sprop, b) + if err != nil && !state.shouldContinue(err, nil) { + return err + } + o.EncodeVarint(uint64((p.Tag << 3) | WireEndGroup)) + return state.err +} + +func size_struct_group(p *Properties, base structPointer) (n int) { + b := structPointer_GetStructPointer(base, p.field) + if structPointer_IsNil(b) { + return 0 + } + + n += sizeVarint(uint64((p.Tag << 3) | WireStartGroup)) + n += size_struct(p.stype, p.sprop, b) + n += sizeVarint(uint64((p.Tag << 3) | WireEndGroup)) + return +} + +// Encode a slice of bools ([]bool). +func (o *Buffer) enc_slice_bool(p *Properties, base structPointer) error { + s := *structPointer_BoolSlice(base, p.field) + l := len(s) + if l == 0 { + return ErrNil + } + for _, x := range s { + o.buf = append(o.buf, p.tagcode...) + v := uint64(0) + if x { + v = 1 + } + p.valEnc(o, v) + } + return nil +} + +func size_slice_bool(p *Properties, base structPointer) int { + s := *structPointer_BoolSlice(base, p.field) + l := len(s) + if l == 0 { + return 0 + } + return l * (len(p.tagcode) + 1) // each bool takes exactly one byte +} + +// Encode a slice of bools ([]bool) in packed format. +func (o *Buffer) enc_slice_packed_bool(p *Properties, base structPointer) error { + s := *structPointer_BoolSlice(base, p.field) + l := len(s) + if l == 0 { + return ErrNil + } + o.buf = append(o.buf, p.tagcode...) + o.EncodeVarint(uint64(l)) // each bool takes exactly one byte + for _, x := range s { + v := uint64(0) + if x { + v = 1 + } + p.valEnc(o, v) + } + return nil +} + +func size_slice_packed_bool(p *Properties, base structPointer) (n int) { + s := *structPointer_BoolSlice(base, p.field) + l := len(s) + if l == 0 { + return 0 + } + n += len(p.tagcode) + n += sizeVarint(uint64(l)) + n += l // each bool takes exactly one byte + return +} + +// Encode a slice of bytes ([]byte). +func (o *Buffer) enc_slice_byte(p *Properties, base structPointer) error { + s := *structPointer_Bytes(base, p.field) + if s == nil { + return ErrNil + } + o.buf = append(o.buf, p.tagcode...) + o.EncodeRawBytes(s) + return nil +} + +func size_slice_byte(p *Properties, base structPointer) (n int) { + s := *structPointer_Bytes(base, p.field) + if s == nil { + return 0 + } + n += len(p.tagcode) + n += sizeRawBytes(s) + return +} + +// Encode a slice of int32s ([]int32). +func (o *Buffer) enc_slice_int32(p *Properties, base structPointer) error { + s := structPointer_Word32Slice(base, p.field) + l := s.Len() + if l == 0 { + return ErrNil + } + for i := 0; i < l; i++ { + o.buf = append(o.buf, p.tagcode...) 
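+		// Non-packed repeated fields repeat the tag before every element.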
+ x := int32(s.Index(i)) // permit sign extension to use full 64-bit range + p.valEnc(o, uint64(x)) + } + return nil +} + +func size_slice_int32(p *Properties, base structPointer) (n int) { + s := structPointer_Word32Slice(base, p.field) + l := s.Len() + if l == 0 { + return 0 + } + for i := 0; i < l; i++ { + n += len(p.tagcode) + x := int32(s.Index(i)) // permit sign extension to use full 64-bit range + n += p.valSize(uint64(x)) + } + return +} + +// Encode a slice of int32s ([]int32) in packed format. +func (o *Buffer) enc_slice_packed_int32(p *Properties, base structPointer) error { + s := structPointer_Word32Slice(base, p.field) + l := s.Len() + if l == 0 { + return ErrNil + } + // TODO: Reuse a Buffer. + buf := NewBuffer(nil) + for i := 0; i < l; i++ { + x := int32(s.Index(i)) // permit sign extension to use full 64-bit range + p.valEnc(buf, uint64(x)) + } + + o.buf = append(o.buf, p.tagcode...) + o.EncodeVarint(uint64(len(buf.buf))) + o.buf = append(o.buf, buf.buf...) + return nil +} + +func size_slice_packed_int32(p *Properties, base structPointer) (n int) { + s := structPointer_Word32Slice(base, p.field) + l := s.Len() + if l == 0 { + return 0 + } + var bufSize int + for i := 0; i < l; i++ { + x := int32(s.Index(i)) // permit sign extension to use full 64-bit range + bufSize += p.valSize(uint64(x)) + } + + n += len(p.tagcode) + n += sizeVarint(uint64(bufSize)) + n += bufSize + return +} + +// Encode a slice of uint32s ([]uint32). +// Exactly the same as int32, except for no sign extension. +func (o *Buffer) enc_slice_uint32(p *Properties, base structPointer) error { + s := structPointer_Word32Slice(base, p.field) + l := s.Len() + if l == 0 { + return ErrNil + } + for i := 0; i < l; i++ { + o.buf = append(o.buf, p.tagcode...) + x := s.Index(i) + p.valEnc(o, uint64(x)) + } + return nil +} + +func size_slice_uint32(p *Properties, base structPointer) (n int) { + s := structPointer_Word32Slice(base, p.field) + l := s.Len() + if l == 0 { + return 0 + } + for i := 0; i < l; i++ { + n += len(p.tagcode) + x := s.Index(i) + n += p.valSize(uint64(x)) + } + return +} + +// Encode a slice of uint32s ([]uint32) in packed format. +// Exactly the same as int32, except for no sign extension. +func (o *Buffer) enc_slice_packed_uint32(p *Properties, base structPointer) error { + s := structPointer_Word32Slice(base, p.field) + l := s.Len() + if l == 0 { + return ErrNil + } + // TODO: Reuse a Buffer. + buf := NewBuffer(nil) + for i := 0; i < l; i++ { + p.valEnc(buf, uint64(s.Index(i))) + } + + o.buf = append(o.buf, p.tagcode...) + o.EncodeVarint(uint64(len(buf.buf))) + o.buf = append(o.buf, buf.buf...) + return nil +} + +func size_slice_packed_uint32(p *Properties, base structPointer) (n int) { + s := structPointer_Word32Slice(base, p.field) + l := s.Len() + if l == 0 { + return 0 + } + var bufSize int + for i := 0; i < l; i++ { + bufSize += p.valSize(uint64(s.Index(i))) + } + + n += len(p.tagcode) + n += sizeVarint(uint64(bufSize)) + n += bufSize + return +} + +// Encode a slice of int64s ([]int64). +func (o *Buffer) enc_slice_int64(p *Properties, base structPointer) error { + s := structPointer_Word64Slice(base, p.field) + l := s.Len() + if l == 0 { + return ErrNil + } + for i := 0; i < l; i++ { + o.buf = append(o.buf, p.tagcode...) 
+ p.valEnc(o, s.Index(i)) + } + return nil +} + +func size_slice_int64(p *Properties, base structPointer) (n int) { + s := structPointer_Word64Slice(base, p.field) + l := s.Len() + if l == 0 { + return 0 + } + for i := 0; i < l; i++ { + n += len(p.tagcode) + n += p.valSize(s.Index(i)) + } + return +} + +// Encode a slice of int64s ([]int64) in packed format. +func (o *Buffer) enc_slice_packed_int64(p *Properties, base structPointer) error { + s := structPointer_Word64Slice(base, p.field) + l := s.Len() + if l == 0 { + return ErrNil + } + // TODO: Reuse a Buffer. + buf := NewBuffer(nil) + for i := 0; i < l; i++ { + p.valEnc(buf, s.Index(i)) + } + + o.buf = append(o.buf, p.tagcode...) + o.EncodeVarint(uint64(len(buf.buf))) + o.buf = append(o.buf, buf.buf...) + return nil +} + +func size_slice_packed_int64(p *Properties, base structPointer) (n int) { + s := structPointer_Word64Slice(base, p.field) + l := s.Len() + if l == 0 { + return 0 + } + var bufSize int + for i := 0; i < l; i++ { + bufSize += p.valSize(s.Index(i)) + } + + n += len(p.tagcode) + n += sizeVarint(uint64(bufSize)) + n += bufSize + return +} + +// Encode a slice of slice of bytes ([][]byte). +func (o *Buffer) enc_slice_slice_byte(p *Properties, base structPointer) error { + ss := *structPointer_BytesSlice(base, p.field) + l := len(ss) + if l == 0 { + return ErrNil + } + for i := 0; i < l; i++ { + o.buf = append(o.buf, p.tagcode...) + o.EncodeRawBytes(ss[i]) + } + return nil +} + +func size_slice_slice_byte(p *Properties, base structPointer) (n int) { + ss := *structPointer_BytesSlice(base, p.field) + l := len(ss) + if l == 0 { + return 0 + } + n += l * len(p.tagcode) + for i := 0; i < l; i++ { + n += sizeRawBytes(ss[i]) + } + return +} + +// Encode a slice of strings ([]string). +func (o *Buffer) enc_slice_string(p *Properties, base structPointer) error { + ss := *structPointer_StringSlice(base, p.field) + l := len(ss) + for i := 0; i < l; i++ { + o.buf = append(o.buf, p.tagcode...) + o.EncodeStringBytes(ss[i]) + } + return nil +} + +func size_slice_string(p *Properties, base structPointer) (n int) { + ss := *structPointer_StringSlice(base, p.field) + l := len(ss) + n += l * len(p.tagcode) + for i := 0; i < l; i++ { + n += sizeStringBytes(ss[i]) + } + return +} + +// Encode a slice of message structs ([]*struct). +func (o *Buffer) enc_slice_struct_message(p *Properties, base structPointer) error { + var state errorState + s := structPointer_StructPointerSlice(base, p.field) + l := s.Len() + + for i := 0; i < l; i++ { + structp := s.Index(i) + if structPointer_IsNil(structp) { + return ErrRepeatedHasNil + } + + // Can the object marshal itself? + if p.isMarshaler { + m := structPointer_Interface(structp, p.stype).(Marshaler) + data, err := m.Marshal() + if err != nil && !state.shouldContinue(err, nil) { + return err + } + o.buf = append(o.buf, p.tagcode...) + o.EncodeRawBytes(data) + continue + } + + o.buf = append(o.buf, p.tagcode...) + err := o.enc_len_struct(p.stype, p.sprop, structp, &state) + if err != nil && !state.shouldContinue(err, nil) { + if err == ErrNil { + return ErrRepeatedHasNil + } + return err + } + } + return state.err +} + +func size_slice_struct_message(p *Properties, base structPointer) (n int) { + s := structPointer_StructPointerSlice(base, p.field) + l := s.Len() + n += l * len(p.tagcode) + for i := 0; i < l; i++ { + structp := s.Index(i) + if structPointer_IsNil(structp) { + return // return the size up to this point + } + + // Can the object marshal itself? 
+ if p.isMarshaler { + m := structPointer_Interface(structp, p.stype).(Marshaler) + data, _ := m.Marshal() + n += len(p.tagcode) + n += sizeRawBytes(data) + continue + } + + n0 := size_struct(p.stype, p.sprop, structp) + n1 := sizeVarint(uint64(n0)) // size of encoded length + n += n0 + n1 + } + return +} + +// Encode a slice of group structs ([]*struct). +func (o *Buffer) enc_slice_struct_group(p *Properties, base structPointer) error { + var state errorState + s := structPointer_StructPointerSlice(base, p.field) + l := s.Len() + + for i := 0; i < l; i++ { + b := s.Index(i) + if structPointer_IsNil(b) { + return ErrRepeatedHasNil + } + + o.EncodeVarint(uint64((p.Tag << 3) | WireStartGroup)) + + err := o.enc_struct(p.stype, p.sprop, b) + + if err != nil && !state.shouldContinue(err, nil) { + if err == ErrNil { + return ErrRepeatedHasNil + } + return err + } + + o.EncodeVarint(uint64((p.Tag << 3) | WireEndGroup)) + } + return state.err +} + +func size_slice_struct_group(p *Properties, base structPointer) (n int) { + s := structPointer_StructPointerSlice(base, p.field) + l := s.Len() + + n += l * sizeVarint(uint64((p.Tag<<3)|WireStartGroup)) + n += l * sizeVarint(uint64((p.Tag<<3)|WireEndGroup)) + for i := 0; i < l; i++ { + b := s.Index(i) + if structPointer_IsNil(b) { + return // return size up to this point + } + + n += size_struct(p.stype, p.sprop, b) + } + return +} + +// Encode an extension map. +func (o *Buffer) enc_map(p *Properties, base structPointer) error { + v := *structPointer_ExtMap(base, p.field) + if err := encodeExtensionMap(v); err != nil { + return err + } + // Fast-path for common cases: zero or one extensions. + if len(v) <= 1 { + for _, e := range v { + o.buf = append(o.buf, e.enc...) + } + return nil + } + + // Sort keys to provide a deterministic encoding. + keys := make([]int, 0, len(v)) + for k := range v { + keys = append(keys, int(k)) + } + sort.Ints(keys) + + for _, k := range keys { + o.buf = append(o.buf, v[int32(k)].enc...) + } + return nil +} + +func size_map(p *Properties, base structPointer) int { + v := *structPointer_ExtMap(base, p.field) + return sizeExtensionMap(v) +} + +// Encode a struct. +func (o *Buffer) enc_struct(t reflect.Type, prop *StructProperties, base structPointer) error { + var state errorState + // Encode fields in tag order so that decoders may use optimizations + // that depend on the ordering. + // http://code.google.com/apis/protocolbuffers/docs/encoding.html#order + for _, i := range prop.order { + p := prop.Prop[i] + if p.enc != nil { + err := p.enc(o, p, base) + if err != nil { + if err == ErrNil { + if p.Required && state.err == nil { + state.err = &RequiredNotSetError{p.Name} + } + } else if !state.shouldContinue(err, p) { + return err + } + } + } + } + + // Add unrecognized fields at the end. + if prop.unrecField.IsValid() { + v := *structPointer_Bytes(base, prop.unrecField) + if len(v) > 0 { + o.buf = append(o.buf, v...) + } + } + + return state.err +} + +func size_struct(t reflect.Type, prop *StructProperties, base structPointer) (n int) { + for _, i := range prop.order { + p := prop.Prop[i] + if p.size != nil { + n += p.size(p, base) + } + } + + // Add unrecognized fields at the end. + if prop.unrecField.IsValid() { + v := *structPointer_Bytes(base, prop.unrecField) + n += len(v) + } + + return +} + +var zeroes [20]byte // longer than any conceivable sizeVarint + +// Encode a struct, preceded by its encoded length (as a varint). 
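+// Four bytes are reserved for the length up front; once the struct has been
+// encoded, the body is shifted so the actual varint-encoded length fits
+// exactly before it.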
+func (o *Buffer) enc_len_struct(t reflect.Type, prop *StructProperties, base structPointer, state *errorState) error { + iLen := len(o.buf) + o.buf = append(o.buf, 0, 0, 0, 0) // reserve four bytes for length + iMsg := len(o.buf) + err := o.enc_struct(t, prop, base) + if err != nil && !state.shouldContinue(err, nil) { + return err + } + lMsg := len(o.buf) - iMsg + lLen := sizeVarint(uint64(lMsg)) + switch x := lLen - (iMsg - iLen); { + case x > 0: // actual length is x bytes larger than the space we reserved + // Move msg x bytes right. + o.buf = append(o.buf, zeroes[:x]...) + copy(o.buf[iMsg+x:], o.buf[iMsg:iMsg+lMsg]) + case x < 0: // actual length is x bytes smaller than the space we reserved + // Move msg x bytes left. + copy(o.buf[iMsg+x:], o.buf[iMsg:iMsg+lMsg]) + o.buf = o.buf[:len(o.buf)+x] // x is negative + } + // Encode the length in the reserved space. + o.buf = o.buf[:iLen] + o.EncodeVarint(uint64(lMsg)) + o.buf = o.buf[:len(o.buf)+lMsg] + return state.err +} + +// errorState maintains the first error that occurs and updates that error +// with additional context. +type errorState struct { + err error +} + +// shouldContinue reports whether encoding should continue upon encountering the +// given error. If the error is RequiredNotSetError, shouldContinue returns true +// and, if this is the first appearance of that error, remembers it for future +// reporting. +// +// If prop is not nil, it may update any error with additional context about the +// field with the error. +func (s *errorState) shouldContinue(err error, prop *Properties) bool { + // Ignore unset required fields. + reqNotSet, ok := err.(*RequiredNotSetError) + if !ok { + return false + } + if s.err == nil { + if prop != nil { + err = &RequiredNotSetError{prop.Name + "." + reqNotSet.field} + } + s.err = err + } + return true +} diff --git a/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/equal.go b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/equal.go new file mode 100644 index 0000000000..42542e6928 --- /dev/null +++ b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/equal.go @@ -0,0 +1,241 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2011 The Go Authors. All rights reserved. +// http://code.google.com/p/goprotobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Protocol buffer comparison. +// TODO: MessageSet. + +package proto + +import ( + "bytes" + "log" + "reflect" + "strings" +) + +/* +Equal returns true iff protocol buffers a and b are equal. +The arguments must both be pointers to protocol buffer structs. + +Equality is defined in this way: + - Two messages are equal iff they are the same type, + corresponding fields are equal, unknown field sets + are equal, and extensions sets are equal. + - Two set scalar fields are equal iff their values are equal. + If the fields are of a floating-point type, remember that + NaN != x for all x, including NaN. + - Two repeated fields are equal iff their lengths are the same, + and their corresponding elements are equal (a "bytes" field, + although represented by []byte, is not a repeated field) + - Two unset fields are equal. + - Two unknown field sets are equal if their current + encoded state is equal. (TODO) + - Two extension sets are equal iff they have corresponding + elements that are pairwise equal. + - Every other combination of things are not equal. + +The return value is undefined if a and b are not protocol buffers. +*/ +func Equal(a, b Message) bool { + if a == nil || b == nil { + return a == b + } + v1, v2 := reflect.ValueOf(a), reflect.ValueOf(b) + if v1.Type() != v2.Type() { + return false + } + if v1.Kind() == reflect.Ptr { + if v1.IsNil() { + return v2.IsNil() + } + if v2.IsNil() { + return false + } + v1, v2 = v1.Elem(), v2.Elem() + } + if v1.Kind() != reflect.Struct { + return false + } + return equalStruct(v1, v2) +} + +// v1 and v2 are known to have the same type. +func equalStruct(v1, v2 reflect.Value) bool { + for i := 0; i < v1.NumField(); i++ { + f := v1.Type().Field(i) + if strings.HasPrefix(f.Name, "XXX_") { + continue + } + f1, f2 := v1.Field(i), v2.Field(i) + if f.Type.Kind() == reflect.Ptr { + if n1, n2 := f1.IsNil(), f2.IsNil(); n1 && n2 { + // both unset + continue + } else if n1 != n2 { + // set/unset mismatch + return false + } + b1, ok := f1.Interface().(raw) + if ok { + b2 := f2.Interface().(raw) + // RawMessage + if !bytes.Equal(b1.Bytes(), b2.Bytes()) { + return false + } + continue + } + f1, f2 = f1.Elem(), f2.Elem() + } + if !equalAny(f1, f2) { + return false + } + } + + if em1 := v1.FieldByName("XXX_extensions"); em1.IsValid() { + em2 := v2.FieldByName("XXX_extensions") + if !equalExtensions(v1.Type(), em1.Interface().(map[int32]Extension), em2.Interface().(map[int32]Extension)) { + return false + } + } + + uf := v1.FieldByName("XXX_unrecognized") + if !uf.IsValid() { + return true + } + + u1 := uf.Bytes() + u2 := v2.FieldByName("XXX_unrecognized").Bytes() + if !bytes.Equal(u1, u2) { + return false + } + + return true +} + +// v1 and v2 are known to have the same type. 
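+// equalAny recurses through pointers, slices, and structs; []byte gets a
+// short-circuit comparison and message types are delegated back to Equal.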
+func equalAny(v1, v2 reflect.Value) bool { + if v1.Type() == protoMessageType { + m1, _ := v1.Interface().(Message) + m2, _ := v2.Interface().(Message) + return Equal(m1, m2) + } + switch v1.Kind() { + case reflect.Bool: + return v1.Bool() == v2.Bool() + case reflect.Float32, reflect.Float64: + return v1.Float() == v2.Float() + case reflect.Int32, reflect.Int64: + return v1.Int() == v2.Int() + case reflect.Ptr: + return equalAny(v1.Elem(), v2.Elem()) + case reflect.Slice: + if v1.Type().Elem().Kind() == reflect.Uint8 { + // short circuit: []byte + if v1.IsNil() != v2.IsNil() { + return false + } + return bytes.Equal(v1.Interface().([]byte), v2.Interface().([]byte)) + } + + if v1.Len() != v2.Len() { + return false + } + for i := 0; i < v1.Len(); i++ { + if !equalAny(v1.Index(i), v2.Index(i)) { + return false + } + } + return true + case reflect.String: + return v1.Interface().(string) == v2.Interface().(string) + case reflect.Struct: + return equalStruct(v1, v2) + case reflect.Uint32, reflect.Uint64: + return v1.Uint() == v2.Uint() + } + + // unknown type, so not a protocol buffer + log.Printf("proto: don't know how to compare %v", v1) + return false +} + +// base is the struct type that the extensions are based on. +// em1 and em2 are extension maps. +func equalExtensions(base reflect.Type, em1, em2 map[int32]Extension) bool { + if len(em1) != len(em2) { + return false + } + + for extNum, e1 := range em1 { + e2, ok := em2[extNum] + if !ok { + return false + } + + m1, m2 := e1.value, e2.value + + if m1 != nil && m2 != nil { + // Both are unencoded. + if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2)) { + return false + } + continue + } + + // At least one is encoded. To do a semantically correct comparison + // we need to unmarshal them first. + var desc *ExtensionDesc + if m := extensionMaps[base]; m != nil { + desc = m[extNum] + } + if desc == nil { + log.Printf("proto: don't know how to compare extension %d of %v", extNum, base) + continue + } + var err error + if m1 == nil { + m1, err = decodeExtension(e1.enc, desc) + } + if m2 == nil && err == nil { + m2, err = decodeExtension(e2.enc, desc) + } + if err != nil { + // The encoded form is invalid. + log.Printf("proto: badly encoded extension %d of %v: %v", extNum, base, err) + return false + } + if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2)) { + return false + } + } + + return true +} diff --git a/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/equal_test.go b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/equal_test.go new file mode 100644 index 0000000000..cdf3fd9c24 --- /dev/null +++ b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/equal_test.go @@ -0,0 +1,166 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2011 The Go Authors. All rights reserved. +// http://code.google.com/p/goprotobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. 
nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto_test + +import ( + "testing" + + pb "./testdata" + . "code.google.com/p/goprotobuf/proto" +) + +// Four identical base messages. +// The init function adds extensions to some of them. +var messageWithoutExtension = &pb.MyMessage{Count: Int32(7)} +var messageWithExtension1a = &pb.MyMessage{Count: Int32(7)} +var messageWithExtension1b = &pb.MyMessage{Count: Int32(7)} +var messageWithExtension2 = &pb.MyMessage{Count: Int32(7)} + +// Two messages with non-message extensions. +var messageWithInt32Extension1 = &pb.MyMessage{Count: Int32(8)} +var messageWithInt32Extension2 = &pb.MyMessage{Count: Int32(8)} + +func init() { + ext1 := &pb.Ext{Data: String("Kirk")} + ext2 := &pb.Ext{Data: String("Picard")} + + // messageWithExtension1a has ext1, but never marshals it. + if err := SetExtension(messageWithExtension1a, pb.E_Ext_More, ext1); err != nil { + panic("SetExtension on 1a failed: " + err.Error()) + } + + // messageWithExtension1b is the unmarshaled form of messageWithExtension1a. + if err := SetExtension(messageWithExtension1b, pb.E_Ext_More, ext1); err != nil { + panic("SetExtension on 1b failed: " + err.Error()) + } + buf, err := Marshal(messageWithExtension1b) + if err != nil { + panic("Marshal of 1b failed: " + err.Error()) + } + messageWithExtension1b.Reset() + if err := Unmarshal(buf, messageWithExtension1b); err != nil { + panic("Unmarshal of 1b failed: " + err.Error()) + } + + // messageWithExtension2 has ext2. 
+	if err := SetExtension(messageWithExtension2, pb.E_Ext_More, ext2); err != nil {
+		panic("SetExtension on 2 failed: " + err.Error())
+	}
+
+	if err := SetExtension(messageWithInt32Extension1, pb.E_Ext_Number, Int32(23)); err != nil {
+		panic("SetExtension on Int32-1 failed: " + err.Error())
+	}
+	if err := SetExtension(messageWithInt32Extension2, pb.E_Ext_Number, Int32(24)); err != nil {
+		panic("SetExtension on Int32-2 failed: " + err.Error())
+	}
+}
+
+var EqualTests = []struct {
+	desc string
+	a, b Message
+	exp  bool
+}{
+	{"different types", &pb.GoEnum{}, &pb.GoTestField{}, false},
+	{"equal empty", &pb.GoEnum{}, &pb.GoEnum{}, true},
+	{"nil vs nil", nil, nil, true},
+	{"typed nil vs typed nil", (*pb.GoEnum)(nil), (*pb.GoEnum)(nil), true},
+	{"typed nil vs empty", (*pb.GoEnum)(nil), &pb.GoEnum{}, false},
+	{"different typed nil", (*pb.GoEnum)(nil), (*pb.GoTestField)(nil), false},
+
+	{"one set field, one unset field", &pb.GoTestField{Label: String("foo")}, &pb.GoTestField{}, false},
+	{"one set field zero, one unset field", &pb.GoTest{Param: Int32(0)}, &pb.GoTest{}, false},
+	{"different set fields", &pb.GoTestField{Label: String("foo")}, &pb.GoTestField{Label: String("bar")}, false},
+	{"equal set", &pb.GoTestField{Label: String("foo")}, &pb.GoTestField{Label: String("foo")}, true},
+
+	{"repeated, one set", &pb.GoTest{F_Int32Repeated: []int32{2, 3}}, &pb.GoTest{}, false},
+	{"repeated, different length", &pb.GoTest{F_Int32Repeated: []int32{2, 3}}, &pb.GoTest{F_Int32Repeated: []int32{2}}, false},
+	{"repeated, different value", &pb.GoTest{F_Int32Repeated: []int32{2}}, &pb.GoTest{F_Int32Repeated: []int32{3}}, false},
+	{"repeated, equal", &pb.GoTest{F_Int32Repeated: []int32{2, 4}}, &pb.GoTest{F_Int32Repeated: []int32{2, 4}}, true},
+	{"repeated, nil equal nil", &pb.GoTest{F_Int32Repeated: nil}, &pb.GoTest{F_Int32Repeated: nil}, true},
+	{"repeated, nil equal empty", &pb.GoTest{F_Int32Repeated: nil}, &pb.GoTest{F_Int32Repeated: []int32{}}, true},
+	{"repeated, empty equal nil", &pb.GoTest{F_Int32Repeated: []int32{}}, &pb.GoTest{F_Int32Repeated: nil}, true},
+
+	{
+		"nested, different",
+		&pb.GoTest{RequiredField: &pb.GoTestField{Label: String("foo")}},
+		&pb.GoTest{RequiredField: &pb.GoTestField{Label: String("bar")}},
+		false,
+	},
+	{
+		"nested, equal",
+		&pb.GoTest{RequiredField: &pb.GoTestField{Label: String("wow")}},
+		&pb.GoTest{RequiredField: &pb.GoTestField{Label: String("wow")}},
+		true,
+	},
+
+	{"bytes", &pb.OtherMessage{Value: []byte("foo")}, &pb.OtherMessage{Value: []byte("foo")}, true},
+	{"bytes, empty", &pb.OtherMessage{Value: []byte{}}, &pb.OtherMessage{Value: []byte{}}, true},
+	{"bytes, empty vs nil", &pb.OtherMessage{Value: []byte{}}, &pb.OtherMessage{Value: nil}, false},
+	{
+		"repeated bytes",
+		&pb.MyMessage{RepBytes: [][]byte{[]byte("sham"), []byte("wow")}},
+		&pb.MyMessage{RepBytes: [][]byte{[]byte("sham"), []byte("wow")}},
+		true,
+	},
+
+	{"extension vs. no extension", messageWithoutExtension, messageWithExtension1a, false},
+	{"extension vs. same extension", messageWithExtension1a, messageWithExtension1b, true},
+	{"extension vs. different extension", messageWithExtension1a, messageWithExtension2, false},
+
+	{"int32 extension vs. itself", messageWithInt32Extension1, messageWithInt32Extension1, true},
+	{"int32 extension vs. 
a different int32", messageWithInt32Extension1, messageWithInt32Extension2, false}, + + { + "message with group", + &pb.MyMessage{ + Count: Int32(1), + Somegroup: &pb.MyMessage_SomeGroup{ + GroupField: Int32(5), + }, + }, + &pb.MyMessage{ + Count: Int32(1), + Somegroup: &pb.MyMessage_SomeGroup{ + GroupField: Int32(5), + }, + }, + true, + }, +} + +func TestEqual(t *testing.T) { + for _, tc := range EqualTests { + if res := Equal(tc.a, tc.b); res != tc.exp { + t.Errorf("%v: Equal(%v, %v) = %v, want %v", tc.desc, tc.a, tc.b, res, tc.exp) + } + } +} diff --git a/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/extensions.go b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/extensions.go new file mode 100644 index 0000000000..e592053c51 --- /dev/null +++ b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/extensions.go @@ -0,0 +1,351 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2010 The Go Authors. All rights reserved. +// http://code.google.com/p/goprotobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto + +/* + * Types and routines for supporting protocol buffer extensions. + */ + +import ( + "errors" + "reflect" + "strconv" + "sync" +) + +// ErrMissingExtension is the error returned by GetExtension if the named extension is not in the message. +var ErrMissingExtension = errors.New("proto: missing extension") + +// ExtensionRange represents a range of message extensions for a protocol buffer. +// Used in code generated by the protocol compiler. +type ExtensionRange struct { + Start, End int32 // both inclusive +} + +// extendableProto is an interface implemented by any protocol buffer that may be extended. +type extendableProto interface { + Message + ExtensionRangeArray() []ExtensionRange + ExtensionMap() map[int32]Extension +} + +var extendableProtoType = reflect.TypeOf((*extendableProto)(nil)).Elem() + +// ExtensionDesc represents an extension specification. +// Used in generated code from the protocol compiler. 
+type ExtensionDesc struct { + ExtendedType Message // nil pointer to the type that is being extended + ExtensionType interface{} // nil pointer to the extension type + Field int32 // field number + Name string // fully-qualified name of extension, for text formatting + Tag string // protobuf tag style +} + +func (ed *ExtensionDesc) repeated() bool { + t := reflect.TypeOf(ed.ExtensionType) + return t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8 +} + +// Extension represents an extension in a message. +type Extension struct { + // When an extension is stored in a message using SetExtension + // only desc and value are set. When the message is marshaled + // enc will be set to the encoded form of the message. + // + // When a message is unmarshaled and contains extensions, each + // extension will have only enc set. When such an extension is + // accessed using GetExtension (or GetExtensions) desc and value + // will be set. + desc *ExtensionDesc + value interface{} + enc []byte +} + +// SetRawExtension is for testing only. +func SetRawExtension(base extendableProto, id int32, b []byte) { + base.ExtensionMap()[id] = Extension{enc: b} +} + +// isExtensionField returns true iff the given field number is in an extension range. +func isExtensionField(pb extendableProto, field int32) bool { + for _, er := range pb.ExtensionRangeArray() { + if er.Start <= field && field <= er.End { + return true + } + } + return false +} + +// checkExtensionTypes checks that the given extension is valid for pb. +func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error { + // Check the extended type. + if a, b := reflect.TypeOf(pb), reflect.TypeOf(extension.ExtendedType); a != b { + return errors.New("proto: bad extended type; " + b.String() + " does not extend " + a.String()) + } + // Check the range. + if !isExtensionField(pb, extension.Field) { + return errors.New("proto: bad extension number; not in declared ranges") + } + return nil +} + +// extPropKey is sufficient to uniquely identify an extension. +type extPropKey struct { + base reflect.Type + field int32 +} + +var extProp = struct { + sync.RWMutex + m map[extPropKey]*Properties +}{ + m: make(map[extPropKey]*Properties), +} + +func extensionProperties(ed *ExtensionDesc) *Properties { + key := extPropKey{base: reflect.TypeOf(ed.ExtendedType), field: ed.Field} + + extProp.RLock() + if prop, ok := extProp.m[key]; ok { + extProp.RUnlock() + return prop + } + extProp.RUnlock() + + extProp.Lock() + defer extProp.Unlock() + // Check again. + if prop, ok := extProp.m[key]; ok { + return prop + } + + prop := new(Properties) + prop.Init(reflect.TypeOf(ed.ExtensionType), "unknown_name", ed.Tag, nil) + extProp.m[key] = prop + return prop +} + +// encodeExtensionMap encodes any unmarshaled (unencoded) extensions in m. +func encodeExtensionMap(m map[int32]Extension) error { + for k, e := range m { + if e.value == nil || e.desc == nil { + // Extension is only in its encoded form. + continue + } + + // We don't skip extensions that have an encoded form set, + // because the extension value may have been mutated after + // the last time this function was called. + + et := reflect.TypeOf(e.desc.ExtensionType) + props := extensionProperties(e.desc) + + p := NewBuffer(nil) + // If e.value has type T, the encoder expects a *struct{ X T }. + // Pass a *T with a zero field and hope it all works out. 
+ x := reflect.New(et) + x.Elem().Set(reflect.ValueOf(e.value)) + if err := props.enc(p, props, toStructPointer(x)); err != nil { + return err + } + e.enc = p.buf + m[k] = e + } + return nil +} + +func sizeExtensionMap(m map[int32]Extension) (n int) { + for _, e := range m { + if e.value == nil || e.desc == nil { + // Extension is only in its encoded form. + n += len(e.enc) + continue + } + + // We don't skip extensions that have an encoded form set, + // because the extension value may have been mutated after + // the last time this function was called. + + et := reflect.TypeOf(e.desc.ExtensionType) + props := extensionProperties(e.desc) + + // If e.value has type T, the encoder expects a *struct{ X T }. + // Pass a *T with a zero field and hope it all works out. + x := reflect.New(et) + x.Elem().Set(reflect.ValueOf(e.value)) + n += props.size(props, toStructPointer(x)) + } + return +} + +// HasExtension returns whether the given extension is present in pb. +func HasExtension(pb extendableProto, extension *ExtensionDesc) bool { + // TODO: Check types, field numbers, etc.? + _, ok := pb.ExtensionMap()[extension.Field] + return ok +} + +// ClearExtension removes the given extension from pb. +func ClearExtension(pb extendableProto, extension *ExtensionDesc) { + // TODO: Check types, field numbers, etc.? + delete(pb.ExtensionMap(), extension.Field) +} + +// GetExtension parses and returns the given extension of pb. +// If the extension is not present it returns ErrMissingExtension. +func GetExtension(pb extendableProto, extension *ExtensionDesc) (interface{}, error) { + if err := checkExtensionTypes(pb, extension); err != nil { + return nil, err + } + + e, ok := pb.ExtensionMap()[extension.Field] + if !ok { + return nil, ErrMissingExtension + } + if e.value != nil { + // Already decoded. Check the descriptor, though. + if e.desc != extension { + // This shouldn't happen. If it does, it means that + // GetExtension was called twice with two different + // descriptors with the same field number. + return nil, errors.New("proto: descriptor conflict") + } + return e.value, nil + } + + v, err := decodeExtension(e.enc, extension) + if err != nil { + return nil, err + } + + // Remember the decoded version and drop the encoded version. + // That way it is safe to mutate what we return. + e.value = v + e.desc = extension + e.enc = nil + return e.value, nil +} + +// decodeExtension decodes an extension encoded in b. +func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) { + o := NewBuffer(b) + + t := reflect.TypeOf(extension.ExtensionType) + rep := extension.repeated() + + props := extensionProperties(extension) + + // t is a pointer to a struct, pointer to basic type or a slice. + // Allocate a "field" to store the pointer/slice itself; the + // pointer/slice will be stored here. We pass + // the address of this field to props.dec. + // This passes a zero field and a *t and lets props.dec + // interpret it as a *struct{ x t }. + value := reflect.New(t).Elem() + + for { + // Discard wire type and field number varint. It isn't needed. + if _, err := o.DecodeVarint(); err != nil { + return nil, err + } + + if err := props.dec(o, props, toStructPointer(value.Addr())); err != nil { + return nil, err + } + + if !rep || o.index >= len(o.buf) { + break + } + } + return value.Interface(), nil +} + +// GetExtensions returns a slice of the extensions present in pb that are also listed in es. +// The returned slice has the same length as es; missing extensions will appear as nil elements. 
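+//
+// A minimal sketch of typical use (the extension descriptors are
+// illustrative):
+//
+//	exts, err := proto.GetExtensions(msg, []*proto.ExtensionDesc{
+//		pb.E_Ext_More,
+//		pb.E_Ext_Text,
+//	})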
+func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) { + epb, ok := pb.(extendableProto) + if !ok { + err = errors.New("proto: not an extendable proto") + return + } + extensions = make([]interface{}, len(es)) + for i, e := range es { + extensions[i], err = GetExtension(epb, e) + if err == ErrMissingExtension { + err = nil + } + if err != nil { + return + } + } + return +} + +// SetExtension sets the specified extension of pb to the specified value. +func SetExtension(pb extendableProto, extension *ExtensionDesc, value interface{}) error { + if err := checkExtensionTypes(pb, extension); err != nil { + return err + } + typ := reflect.TypeOf(extension.ExtensionType) + if typ != reflect.TypeOf(value) { + return errors.New("proto: bad extension value type") + } + + pb.ExtensionMap()[extension.Field] = Extension{desc: extension, value: value} + return nil +} + +// A global registry of extensions. +// The generated code will register the generated descriptors by calling RegisterExtension. + +var extensionMaps = make(map[reflect.Type]map[int32]*ExtensionDesc) + +// RegisterExtension is called from the generated code. +func RegisterExtension(desc *ExtensionDesc) { + st := reflect.TypeOf(desc.ExtendedType).Elem() + m := extensionMaps[st] + if m == nil { + m = make(map[int32]*ExtensionDesc) + extensionMaps[st] = m + } + if _, ok := m[desc.Field]; ok { + panic("proto: duplicate extension registered: " + st.String() + " " + strconv.Itoa(int(desc.Field))) + } + m[desc.Field] = desc +} + +// RegisteredExtensions returns a map of the registered extensions of a +// protocol buffer struct, indexed by the extension number. +// The argument pb should be a nil pointer to the struct type. +func RegisteredExtensions(pb Message) map[int32]*ExtensionDesc { + return extensionMaps[reflect.TypeOf(pb).Elem()] +} diff --git a/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/extensions_test.go b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/extensions_test.go new file mode 100644 index 0000000000..03f756b6f8 --- /dev/null +++ b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/extensions_test.go @@ -0,0 +1,60 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2014 The Go Authors. All rights reserved. +// http://code.google.com/p/goprotobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+package proto_test
+
+import (
+	"testing"
+
+	pb "./testdata"
+	"code.google.com/p/goprotobuf/proto"
+)
+
+func TestGetExtensionsWithMissingExtensions(t *testing.T) {
+	msg := &pb.MyMessage{}
+	ext1 := &pb.Ext{}
+	if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
+		t.Fatalf("Could not set ext1: %s", err)
+	}
+	exts, err := proto.GetExtensions(msg, []*proto.ExtensionDesc{
+		pb.E_Ext_More,
+		pb.E_Ext_Text,
+	})
+	if err != nil {
+		t.Fatalf("GetExtensions() failed: %s", err)
+	}
+	if exts[0] != ext1 {
+		t.Errorf("ext1 not in returned extensions: %T %v", exts[0], exts[0])
+	}
+	if exts[1] != nil {
+		t.Errorf("ext2 in returned extensions: %T %v", exts[1], exts[1])
+	}
+}
diff --git a/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/lib.go b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/lib.go
new file mode 100644
index 0000000000..46a4416fdf
--- /dev/null
+++ b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/lib.go
@@ -0,0 +1,740 @@
+// Go support for Protocol Buffers - Google's data interchange format
+//
+// Copyright 2010 The Go Authors. All rights reserved.
+// http://code.google.com/p/goprotobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+/*
+	Package proto converts data structures to and from the wire format of
+	protocol buffers. It works in concert with the Go source code generated
+	for .proto files by the protocol compiler.
+
+	A summary of the properties of the protocol buffer interface
+	for a protocol buffer variable v:
+
+	  - Names are turned from camel_case to CamelCase for export.
+	  - There are no methods on v to set fields; just treat
+	    them as structure fields.
+	  - There are getters that return a field's value if set,
+	    and return the field's default value if unset.
+	    The getters work even if the receiver is a nil message.
+	  - The zero value for a struct is its correct initialization state.
+	    All desired fields must be set before marshaling.
+	  - A Reset() method will restore a protobuf struct to its zero state.
+	  - Non-repeated fields are pointers to the values; nil means unset.
+	    That is, optional or required field int32 f becomes F *int32.
+	  - Repeated fields are slices.
+	  - Helper functions are available to aid the setting of fields.
+	    Helpers for getting values are superseded by the
+	    GetFoo methods and their use is deprecated.
+		msg.Foo = proto.String("hello") // set field
+	  - Constants are defined to hold the default values of all fields that
+	    have them. They have the form Default_StructName_FieldName.
+	    Because the getter methods handle defaulted values,
+	    direct use of these constants should be rare.
+	  - Enums are given type names and maps from names to values.
+	    Enum values are prefixed with the enum's type name. Enum types have
+	    a String method, and an Enum method to assist in message construction.
+	  - Nested groups and enums have type names prefixed with the name of
+	    the surrounding message type.
+	  - Extensions are given descriptor names that start with E_,
+	    followed by an underscore-delimited list of the nested messages
+	    that contain it (if any) followed by the CamelCased name of the
+	    extension field itself. HasExtension, ClearExtension, GetExtension
+	    and SetExtension are functions for manipulating extensions.
+	  - Marshal and Unmarshal are functions to encode and decode the wire format.
+
+	The simplest way to describe this is to see an example.
+ Given file test.proto, containing + + package example; + + enum FOO { X = 17; }; + + message Test { + required string label = 1; + optional int32 type = 2 [default=77]; + repeated int64 reps = 3; + optional group OptionalGroup = 4 { + required string RequiredField = 5; + } + } + + The resulting file, test.pb.go, is: + + package example + + import "code.google.com/p/goprotobuf/proto" + + type FOO int32 + const ( + FOO_X FOO = 17 + ) + var FOO_name = map[int32]string{ + 17: "X", + } + var FOO_value = map[string]int32{ + "X": 17, + } + + func (x FOO) Enum() *FOO { + p := new(FOO) + *p = x + return p + } + func (x FOO) String() string { + return proto.EnumName(FOO_name, int32(x)) + } + + type Test struct { + Label *string `protobuf:"bytes,1,req,name=label" json:"label,omitempty"` + Type *int32 `protobuf:"varint,2,opt,name=type,def=77" json:"type,omitempty"` + Reps []int64 `protobuf:"varint,3,rep,name=reps" json:"reps,omitempty"` + Optionalgroup *Test_OptionalGroup `protobuf:"group,4,opt,name=OptionalGroup" json:"optionalgroup,omitempty"` + XXX_unrecognized []byte `json:"-"` + } + func (this *Test) Reset() { *this = Test{} } + func (this *Test) String() string { return proto.CompactTextString(this) } + const Default_Test_Type int32 = 77 + + func (this *Test) GetLabel() string { + if this != nil && this.Label != nil { + return *this.Label + } + return "" + } + + func (this *Test) GetType() int32 { + if this != nil && this.Type != nil { + return *this.Type + } + return Default_Test_Type + } + + func (this *Test) GetOptionalgroup() *Test_OptionalGroup { + if this != nil { + return this.Optionalgroup + } + return nil + } + + type Test_OptionalGroup struct { + RequiredField *string `protobuf:"bytes,5,req" json:"RequiredField,omitempty"` + XXX_unrecognized []byte `json:"-"` + } + func (this *Test_OptionalGroup) Reset() { *this = Test_OptionalGroup{} } + func (this *Test_OptionalGroup) String() string { return proto.CompactTextString(this) } + + func (this *Test_OptionalGroup) GetRequiredField() string { + if this != nil && this.RequiredField != nil { + return *this.RequiredField + } + return "" + } + + func init() { + proto.RegisterEnum("example.FOO", FOO_name, FOO_value) + } + + To create and play with a Test object: + + package main + + import ( + "log" + + "code.google.com/p/goprotobuf/proto" + "./example.pb" + ) + + func main() { + test := &example.Test{ + Label: proto.String("hello"), + Type: proto.Int32(17), + Optionalgroup: &example.Test_OptionalGroup{ + RequiredField: proto.String("good bye"), + }, + } + data, err := proto.Marshal(test) + if err != nil { + log.Fatal("marshaling error: ", err) + } + newTest := new(example.Test) + err = proto.Unmarshal(data, newTest) + if err != nil { + log.Fatal("unmarshaling error: ", err) + } + // Now test and newTest contain the same data. + if test.GetLabel() != newTest.GetLabel() { + log.Fatalf("data mismatch %q != %q", test.GetLabel(), newTest.GetLabel()) + } + // etc. + } +*/ +package proto + +import ( + "encoding/json" + "fmt" + "log" + "reflect" + "strconv" + "sync" +) + +// Message is implemented by generated protocol buffer messages. +type Message interface { + Reset() + String() string + ProtoMessage() +} + +// Stats records allocation details about the protocol buffer encoders +// and decoders. Useful for tuning the library itself. 
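+//
+// Collection is gated by the collectStats constant below; while it is false
+// (as vendored here) every counter remains zero. A hedged sketch, assuming a
+// local build sets collectStats to true:
+//	s := GetStats()
+//	fmt.Printf("encodes=%d decodes=%d cache hits=%d\n", s.Encode, s.Decode, s.Chit)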
+type Stats struct { + Emalloc uint64 // mallocs in encode + Dmalloc uint64 // mallocs in decode + Encode uint64 // number of encodes + Decode uint64 // number of decodes + Chit uint64 // number of cache hits + Cmiss uint64 // number of cache misses + Size uint64 // number of sizes +} + +// Set to true to enable stats collection. +const collectStats = false + +var stats Stats + +// GetStats returns a copy of the global Stats structure. +func GetStats() Stats { return stats } + +// A Buffer is a buffer manager for marshaling and unmarshaling +// protocol buffers. It may be reused between invocations to +// reduce memory usage. It is not necessary to use a Buffer; +// the global functions Marshal and Unmarshal create a +// temporary Buffer and are fine for most applications. +type Buffer struct { + buf []byte // encode/decode byte stream + index int // write point + + // pools of basic types to amortize allocation. + bools []bool + uint32s []uint32 + uint64s []uint64 + + // extra pools, only used with pointer_reflect.go + int32s []int32 + int64s []int64 + float32s []float32 + float64s []float64 +} + +// NewBuffer allocates a new Buffer and initializes its internal data to +// the contents of the argument slice. +func NewBuffer(e []byte) *Buffer { + return &Buffer{buf: e} +} + +// Reset resets the Buffer, ready for marshaling a new protocol buffer. +func (p *Buffer) Reset() { + p.buf = p.buf[0:0] // for reading/writing + p.index = 0 // for reading +} + +// SetBuf replaces the internal buffer with the slice, +// ready for unmarshaling the contents of the slice. +func (p *Buffer) SetBuf(s []byte) { + p.buf = s + p.index = 0 +} + +// Bytes returns the contents of the Buffer. +func (p *Buffer) Bytes() []byte { return p.buf } + +/* + * Helper routines for simplifying the creation of optional fields of basic type. + */ + +// Bool is a helper routine that allocates a new bool value +// to store v and returns a pointer to it. +func Bool(v bool) *bool { + return &v +} + +// Int32 is a helper routine that allocates a new int32 value +// to store v and returns a pointer to it. +func Int32(v int32) *int32 { + return &v +} + +// Int is a helper routine that allocates a new int32 value +// to store v and returns a pointer to it, but unlike Int32 +// its argument value is an int. +func Int(v int) *int32 { + p := new(int32) + *p = int32(v) + return p +} + +// Int64 is a helper routine that allocates a new int64 value +// to store v and returns a pointer to it. +func Int64(v int64) *int64 { + return &v +} + +// Float32 is a helper routine that allocates a new float32 value +// to store v and returns a pointer to it. +func Float32(v float32) *float32 { + return &v +} + +// Float64 is a helper routine that allocates a new float64 value +// to store v and returns a pointer to it. +func Float64(v float64) *float64 { + return &v +} + +// Uint32 is a helper routine that allocates a new uint32 value +// to store v and returns a pointer to it. +func Uint32(v uint32) *uint32 { + p := new(uint32) + *p = v + return p +} + +// Uint64 is a helper routine that allocates a new uint64 value +// to store v and returns a pointer to it. +func Uint64(v uint64) *uint64 { + return &v +} + +// String is a helper routine that allocates a new string value +// to store v and returns a pointer to it. +func String(v string) *string { + return &v +} + +// EnumName is a helper function to simplify printing protocol buffer enums +// by name. Given an enum map and a value, it returns a useful string. 
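+//
+// For example, with the FOO enum from the package comment:
+//	EnumName(FOO_name, 17) // "X"
+//	EnumName(FOO_name, 99) // "99": unknown values fall back to the decimal string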
+func EnumName(m map[int32]string, v int32) string { + s, ok := m[v] + if ok { + return s + } + return strconv.Itoa(int(v)) +} + +// UnmarshalJSONEnum is a helper function to simplify recovering enum int values +// from their JSON-encoded representation. Given a map from the enum's symbolic +// names to its int values, and a byte buffer containing the JSON-encoded +// value, it returns an int32 that can be cast to the enum type by the caller. +// +// The function can deal with both JSON representations, numeric and symbolic. +func UnmarshalJSONEnum(m map[string]int32, data []byte, enumName string) (int32, error) { + if data[0] == '"' { + // New style: enums are strings. + var repr string + if err := json.Unmarshal(data, &repr); err != nil { + return -1, err + } + val, ok := m[repr] + if !ok { + return 0, fmt.Errorf("unrecognized enum %s value %q", enumName, repr) + } + return val, nil + } + // Old style: enums are ints. + var val int32 + if err := json.Unmarshal(data, &val); err != nil { + return 0, fmt.Errorf("cannot unmarshal %#q into enum %s", data, enumName) + } + return val, nil +} + +// DebugPrint dumps the encoded data in b in a debugging format with a header +// including the string s. Used in testing but made available for general debugging. +func (o *Buffer) DebugPrint(s string, b []byte) { + var u uint64 + + obuf := o.buf + index := o.index + o.buf = b + o.index = 0 + depth := 0 + + fmt.Printf("\n--- %s ---\n", s) + +out: + for { + for i := 0; i < depth; i++ { + fmt.Print(" ") + } + + index := o.index + if index == len(o.buf) { + break + } + + op, err := o.DecodeVarint() + if err != nil { + fmt.Printf("%3d: fetching op err %v\n", index, err) + break out + } + tag := op >> 3 + wire := op & 7 + + switch wire { + default: + fmt.Printf("%3d: t=%3d unknown wire=%d\n", + index, tag, wire) + break out + + case WireBytes: + var r []byte + + r, err = o.DecodeRawBytes(false) + if err != nil { + break out + } + fmt.Printf("%3d: t=%3d bytes [%d]", index, tag, len(r)) + if len(r) <= 6 { + for i := 0; i < len(r); i++ { + fmt.Printf(" %.2x", r[i]) + } + } else { + for i := 0; i < 3; i++ { + fmt.Printf(" %.2x", r[i]) + } + fmt.Printf(" ..") + for i := len(r) - 3; i < len(r); i++ { + fmt.Printf(" %.2x", r[i]) + } + } + fmt.Printf("\n") + + case WireFixed32: + u, err = o.DecodeFixed32() + if err != nil { + fmt.Printf("%3d: t=%3d fix32 err %v\n", index, tag, err) + break out + } + fmt.Printf("%3d: t=%3d fix32 %d\n", index, tag, u) + + case WireFixed64: + u, err = o.DecodeFixed64() + if err != nil { + fmt.Printf("%3d: t=%3d fix64 err %v\n", index, tag, err) + break out + } + fmt.Printf("%3d: t=%3d fix64 %d\n", index, tag, u) + break + + case WireVarint: + u, err = o.DecodeVarint() + if err != nil { + fmt.Printf("%3d: t=%3d varint err %v\n", index, tag, err) + break out + } + fmt.Printf("%3d: t=%3d varint %d\n", index, tag, u) + + case WireStartGroup: + if err != nil { + fmt.Printf("%3d: t=%3d start err %v\n", index, tag, err) + break out + } + fmt.Printf("%3d: t=%3d start\n", index, tag) + depth++ + + case WireEndGroup: + depth-- + if err != nil { + fmt.Printf("%3d: t=%3d end err %v\n", index, tag, err) + break out + } + fmt.Printf("%3d: t=%3d end\n", index, tag) + } + } + + if depth != 0 { + fmt.Printf("%3d: start-end not balanced %d\n", o.index, depth) + } + fmt.Printf("\n") + + o.buf = obuf + o.index = index +} + +// SetDefaults sets unset protocol buffer fields to their default values. +// It only modifies fields that are both unset and have defined defaults. 
+// It recursively sets default values in any non-nil sub-messages. +func SetDefaults(pb Message) { + setDefaults(reflect.ValueOf(pb), true, false) +} + +// v is a pointer to a struct. +func setDefaults(v reflect.Value, recur, zeros bool) { + v = v.Elem() + + defaultMu.RLock() + dm, ok := defaults[v.Type()] + defaultMu.RUnlock() + if !ok { + dm = buildDefaultMessage(v.Type()) + defaultMu.Lock() + defaults[v.Type()] = dm + defaultMu.Unlock() + } + + for _, sf := range dm.scalars { + f := v.Field(sf.index) + if !f.IsNil() { + // field already set + continue + } + dv := sf.value + if dv == nil && !zeros { + // no explicit default, and don't want to set zeros + continue + } + fptr := f.Addr().Interface() // **T + // TODO: Consider batching the allocations we do here. + switch sf.kind { + case reflect.Bool: + b := new(bool) + if dv != nil { + *b = dv.(bool) + } + *(fptr.(**bool)) = b + case reflect.Float32: + f := new(float32) + if dv != nil { + *f = dv.(float32) + } + *(fptr.(**float32)) = f + case reflect.Float64: + f := new(float64) + if dv != nil { + *f = dv.(float64) + } + *(fptr.(**float64)) = f + case reflect.Int32: + // might be an enum + if ft := f.Type(); ft != int32PtrType { + // enum + f.Set(reflect.New(ft.Elem())) + if dv != nil { + f.Elem().SetInt(int64(dv.(int32))) + } + } else { + // int32 field + i := new(int32) + if dv != nil { + *i = dv.(int32) + } + *(fptr.(**int32)) = i + } + case reflect.Int64: + i := new(int64) + if dv != nil { + *i = dv.(int64) + } + *(fptr.(**int64)) = i + case reflect.String: + s := new(string) + if dv != nil { + *s = dv.(string) + } + *(fptr.(**string)) = s + case reflect.Uint8: + // exceptional case: []byte + var b []byte + if dv != nil { + db := dv.([]byte) + b = make([]byte, len(db)) + copy(b, db) + } else { + b = []byte{} + } + *(fptr.(*[]byte)) = b + case reflect.Uint32: + u := new(uint32) + if dv != nil { + *u = dv.(uint32) + } + *(fptr.(**uint32)) = u + case reflect.Uint64: + u := new(uint64) + if dv != nil { + *u = dv.(uint64) + } + *(fptr.(**uint64)) = u + default: + log.Printf("proto: can't set default for field %v (sf.kind=%v)", f, sf.kind) + } + } + + for _, ni := range dm.nested { + f := v.Field(ni) + if f.IsNil() { + continue + } + // f is *T or []*T + if f.Kind() == reflect.Ptr { + setDefaults(f, recur, zeros) + } else { + for i := 0; i < f.Len(); i++ { + e := f.Index(i) + if e.IsNil() { + continue + } + setDefaults(e, recur, zeros) + } + } + } +} + +var ( + // defaults maps a protocol buffer struct type to a slice of the fields, + // with its scalar fields set to their proto-declared non-zero default values. + defaultMu sync.RWMutex + defaults = make(map[reflect.Type]defaultMessage) + + int32PtrType = reflect.TypeOf((*int32)(nil)) +) + +// defaultMessage represents information about the default values of a message. +type defaultMessage struct { + scalars []scalarField + nested []int // struct field index of nested messages +} + +type scalarField struct { + index int // struct field index + kind reflect.Kind // element type (the T in *T or []T) + value interface{} // the proto-declared default value, or nil +} + +func ptrToStruct(t reflect.Type) bool { + return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct +} + +// t is a struct type. 
+func buildDefaultMessage(t reflect.Type) (dm defaultMessage) { + sprop := GetProperties(t) + for _, prop := range sprop.Prop { + fi, ok := sprop.decoderTags.get(prop.Tag) + if !ok { + // XXX_unrecognized + continue + } + ft := t.Field(fi).Type + + // nested messages + if ptrToStruct(ft) || (ft.Kind() == reflect.Slice && ptrToStruct(ft.Elem())) { + dm.nested = append(dm.nested, fi) + continue + } + + sf := scalarField{ + index: fi, + kind: ft.Elem().Kind(), + } + + // scalar fields without defaults + if !prop.HasDefault { + dm.scalars = append(dm.scalars, sf) + continue + } + + // a scalar field: either *T or []byte + switch ft.Elem().Kind() { + case reflect.Bool: + x, err := strconv.ParseBool(prop.Default) + if err != nil { + log.Printf("proto: bad default bool %q: %v", prop.Default, err) + continue + } + sf.value = x + case reflect.Float32: + x, err := strconv.ParseFloat(prop.Default, 32) + if err != nil { + log.Printf("proto: bad default float32 %q: %v", prop.Default, err) + continue + } + sf.value = float32(x) + case reflect.Float64: + x, err := strconv.ParseFloat(prop.Default, 64) + if err != nil { + log.Printf("proto: bad default float64 %q: %v", prop.Default, err) + continue + } + sf.value = x + case reflect.Int32: + x, err := strconv.ParseInt(prop.Default, 10, 32) + if err != nil { + log.Printf("proto: bad default int32 %q: %v", prop.Default, err) + continue + } + sf.value = int32(x) + case reflect.Int64: + x, err := strconv.ParseInt(prop.Default, 10, 64) + if err != nil { + log.Printf("proto: bad default int64 %q: %v", prop.Default, err) + continue + } + sf.value = x + case reflect.String: + sf.value = prop.Default + case reflect.Uint8: + // []byte (not *uint8) + sf.value = []byte(prop.Default) + case reflect.Uint32: + x, err := strconv.ParseUint(prop.Default, 10, 32) + if err != nil { + log.Printf("proto: bad default uint32 %q: %v", prop.Default, err) + continue + } + sf.value = uint32(x) + case reflect.Uint64: + x, err := strconv.ParseUint(prop.Default, 10, 64) + if err != nil { + log.Printf("proto: bad default uint64 %q: %v", prop.Default, err) + continue + } + sf.value = x + default: + log.Printf("proto: unhandled def kind %v", ft.Elem().Kind()) + continue + } + + dm.scalars = append(dm.scalars, sf) + } + + return dm +} diff --git a/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/message_set.go b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/message_set.go new file mode 100644 index 0000000000..1a17809830 --- /dev/null +++ b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/message_set.go @@ -0,0 +1,229 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2010 The Go Authors. All rights reserved. +// http://code.google.com/p/goprotobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto + +/* + * Support for message sets. + */ + +import ( + "errors" + "reflect" + "sort" +) + +// ErrNoMessageTypeId occurs when a protocol buffer does not have a message type ID. +// A message type ID is required for storing a protocol buffer in a message set. +var ErrNoMessageTypeId = errors.New("proto does not have a message type ID") + +// The first two types (_MessageSet_Item and MessageSet) +// model what the protocol compiler produces for the following protocol message: +// message MessageSet { +// repeated group Item = 1 { +// required int32 type_id = 2; +// required string message = 3; +// }; +// } +// That is the MessageSet wire format. We can't use a proto to generate these +// because that would introduce a circular dependency between it and this package. +// +// When a proto1 proto has a field that looks like: +// optional message info = 3; +// the protocol compiler produces a field in the generated struct that looks like: +// Info *_proto_.MessageSet `protobuf:"bytes,3,opt,name=info"` +// The package is automatically inserted so there is no need for that proto file to +// import this package. + +type _MessageSet_Item struct { + TypeId *int32 `protobuf:"varint,2,req,name=type_id"` + Message []byte `protobuf:"bytes,3,req,name=message"` +} + +type MessageSet struct { + Item []*_MessageSet_Item `protobuf:"group,1,rep"` + XXX_unrecognized []byte + // TODO: caching? +} + +// Make sure MessageSet is a Message. +var _ Message = (*MessageSet)(nil) + +// messageTypeIder is an interface satisfied by a protocol buffer type +// that may be stored in a MessageSet. +type messageTypeIder interface { + MessageTypeId() int32 +} + +func (ms *MessageSet) find(pb Message) *_MessageSet_Item { + mti, ok := pb.(messageTypeIder) + if !ok { + return nil + } + id := mti.MessageTypeId() + for _, item := range ms.Item { + if *item.TypeId == id { + return item + } + } + return nil +} + +func (ms *MessageSet) Has(pb Message) bool { + if ms.find(pb) != nil { + return true + } + return false +} + +func (ms *MessageSet) Unmarshal(pb Message) error { + if item := ms.find(pb); item != nil { + return Unmarshal(item.Message, pb) + } + if _, ok := pb.(messageTypeIder); !ok { + return ErrNoMessageTypeId + } + return nil // TODO: return error instead? 
+} + +func (ms *MessageSet) Marshal(pb Message) error { + msg, err := Marshal(pb) + if err != nil { + return err + } + if item := ms.find(pb); item != nil { + // reuse existing item + item.Message = msg + return nil + } + + mti, ok := pb.(messageTypeIder) + if !ok { + return ErrNoMessageTypeId + } + + mtid := mti.MessageTypeId() + ms.Item = append(ms.Item, &_MessageSet_Item{ + TypeId: &mtid, + Message: msg, + }) + return nil +} + +func (ms *MessageSet) Reset() { *ms = MessageSet{} } +func (ms *MessageSet) String() string { return CompactTextString(ms) } +func (*MessageSet) ProtoMessage() {} + +// Support for the message_set_wire_format message option. + +func skipVarint(buf []byte) []byte { + i := 0 + for ; buf[i]&0x80 != 0; i++ { + } + return buf[i+1:] +} + +// MarshalMessageSet encodes the extension map represented by m in the message set wire format. +// It is called by generated Marshal methods on protocol buffer messages with the message_set_wire_format option. +func MarshalMessageSet(m map[int32]Extension) ([]byte, error) { + if err := encodeExtensionMap(m); err != nil { + return nil, err + } + + // Sort extension IDs to provide a deterministic encoding. + // See also enc_map in encode.go. + ids := make([]int, 0, len(m)) + for id := range m { + ids = append(ids, int(id)) + } + sort.Ints(ids) + + ms := &MessageSet{Item: make([]*_MessageSet_Item, 0, len(m))} + for _, id := range ids { + e := m[int32(id)] + // Remove the wire type and field number varint, as well as the length varint. + msg := skipVarint(skipVarint(e.enc)) + + ms.Item = append(ms.Item, &_MessageSet_Item{ + TypeId: Int32(int32(id)), + Message: msg, + }) + } + return Marshal(ms) +} + +// UnmarshalMessageSet decodes the extension map encoded in buf in the message set wire format. +// It is called by generated Unmarshal methods on protocol buffer messages with the message_set_wire_format option. +func UnmarshalMessageSet(buf []byte, m map[int32]Extension) error { + ms := new(MessageSet) + if err := Unmarshal(buf, ms); err != nil { + return err + } + for _, item := range ms.Item { + id := *item.TypeId + msg := item.Message + + // Restore wire type and field number varint, plus length varint. + // Be careful to preserve duplicate items. + b := EncodeVarint(uint64(id)<<3 | WireBytes) + if ext, ok := m[id]; ok { + // Existing data; rip off the tag and length varint + // so we join the new data correctly. + // We can assume that ext.enc is set because we are unmarshaling. + o := ext.enc[len(b):] // skip wire type and field number + _, n := DecodeVarint(o) // calculate length of length varint + o = o[n:] // skip length varint + msg = append(o, msg...) // join old data and new data + } + b = append(b, EncodeVarint(uint64(len(msg)))...) + b = append(b, msg...) + + m[id] = Extension{enc: b} + } + return nil +} + +// A global registry of types that can be used in a MessageSet. + +var messageSetMap = make(map[int32]messageSetDesc) + +type messageSetDesc struct { + t reflect.Type // pointer to struct + name string +} + +// RegisterMessageSetType is called from the generated code. 
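+// It records the concrete type and name of i in messageSetMap, keyed by the
+// message type ID, so MessageSet entries can later be tied back to a Go type.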
+func RegisterMessageSetType(i messageTypeIder, name string) { + messageSetMap[i.MessageTypeId()] = messageSetDesc{ + t: reflect.TypeOf(i), + name: name, + } +} diff --git a/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/message_set_test.go b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/message_set_test.go new file mode 100644 index 0000000000..bb311bcccf --- /dev/null +++ b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/message_set_test.go @@ -0,0 +1,66 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2014 The Go Authors. All rights reserved. +// http://code.google.com/p/goprotobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto + +import ( + "bytes" + "testing" +) + +func TestUnmarshalMessageSetWithDuplicate(t *testing.T) { + // Check that a repeated message set entry will be concatenated. + in := &MessageSet{ + Item: []*_MessageSet_Item{ + {TypeId: Int32(12345), Message: []byte("hoo")}, + {TypeId: Int32(12345), Message: []byte("hah")}, + }, + } + b, err := Marshal(in) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + t.Logf("Marshaled bytes: %q", b) + + m := make(map[int32]Extension) + if err := UnmarshalMessageSet(b, m); err != nil { + t.Fatalf("UnmarshalMessageSet: %v", err) + } + ext, ok := m[12345] + if !ok { + t.Fatalf("Didn't retrieve extension 12345; map is %v", m) + } + // Skip wire type/field number and length varints. + got := skipVarint(skipVarint(ext.enc)) + if want := []byte("hoohah"); !bytes.Equal(got, want) { + t.Errorf("Combined extension is %q, want %q", got, want) + } +} diff --git a/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/pointer_reflect.go b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/pointer_reflect.go new file mode 100644 index 0000000000..61141ba858 --- /dev/null +++ b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/pointer_reflect.go @@ -0,0 +1,384 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2012 The Go Authors. All rights reserved. 
+// http://code.google.com/p/goprotobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// +build appengine,!appenginevm + +// This file contains an implementation of proto field accesses using package reflect. +// It is slower than the code in pointer_unsafe.go but it avoids package unsafe and can +// be used on App Engine. + +package proto + +import ( + "math" + "reflect" +) + +// A structPointer is a pointer to a struct. +type structPointer struct { + v reflect.Value +} + +// toStructPointer returns a structPointer equivalent to the given reflect value. +// The reflect value must itself be a pointer to a struct. +func toStructPointer(v reflect.Value) structPointer { + return structPointer{v} +} + +// IsNil reports whether p is nil. +func structPointer_IsNil(p structPointer) bool { + return p.v.IsNil() +} + +// Interface returns the struct pointer as an interface value. +func structPointer_Interface(p structPointer, _ reflect.Type) interface{} { + return p.v.Interface() +} + +// A field identifies a field in a struct, accessible from a structPointer. +// In this implementation, a field is identified by the sequence of field indices +// passed to reflect's FieldByIndex. +type field []int + +// toField returns a field equivalent to the given reflect field. +func toField(f *reflect.StructField) field { + return f.Index +} + +// invalidField is an invalid field identifier. +var invalidField = field(nil) + +// IsValid reports whether the field identifier is valid. +func (f field) IsValid() bool { return f != nil } + +// field returns the given field in the struct as a reflect value. +func structPointer_field(p structPointer, f field) reflect.Value { + // Special case: an extension map entry with a value of type T + // passes a *T to the struct-handling code with a zero field, + // expecting that it will be treated as equivalent to *struct{ X T }, + // which has the same memory layout. We have to handle that case + // specially, because reflect will panic if we call FieldByIndex on a + // non-struct. 
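+	// When f is nil we therefore return the pointed-at value itself;
+	// it stands in for the synthetic struct's single field.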
+ if f == nil { + return p.v.Elem() + } + + return p.v.Elem().FieldByIndex(f) +} + +// ifield returns the given field in the struct as an interface value. +func structPointer_ifield(p structPointer, f field) interface{} { + return structPointer_field(p, f).Addr().Interface() +} + +// Bytes returns the address of a []byte field in the struct. +func structPointer_Bytes(p structPointer, f field) *[]byte { + return structPointer_ifield(p, f).(*[]byte) +} + +// BytesSlice returns the address of a [][]byte field in the struct. +func structPointer_BytesSlice(p structPointer, f field) *[][]byte { + return structPointer_ifield(p, f).(*[][]byte) +} + +// Bool returns the address of a *bool field in the struct. +func structPointer_Bool(p structPointer, f field) **bool { + return structPointer_ifield(p, f).(**bool) +} + +// BoolSlice returns the address of a []bool field in the struct. +func structPointer_BoolSlice(p structPointer, f field) *[]bool { + return structPointer_ifield(p, f).(*[]bool) +} + +// String returns the address of a *string field in the struct. +func structPointer_String(p structPointer, f field) **string { + return structPointer_ifield(p, f).(**string) +} + +// StringSlice returns the address of a []string field in the struct. +func structPointer_StringSlice(p structPointer, f field) *[]string { + return structPointer_ifield(p, f).(*[]string) +} + +// ExtMap returns the address of an extension map field in the struct. +func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension { + return structPointer_ifield(p, f).(*map[int32]Extension) +} + +// SetStructPointer writes a *struct field in the struct. +func structPointer_SetStructPointer(p structPointer, f field, q structPointer) { + structPointer_field(p, f).Set(q.v) +} + +// GetStructPointer reads a *struct field in the struct. +func structPointer_GetStructPointer(p structPointer, f field) structPointer { + return structPointer{structPointer_field(p, f)} +} + +// StructPointerSlice the address of a []*struct field in the struct. +func structPointer_StructPointerSlice(p structPointer, f field) structPointerSlice { + return structPointerSlice{structPointer_field(p, f)} +} + +// A structPointerSlice represents the address of a slice of pointers to structs +// (themselves messages or groups). That is, v.Type() is *[]*struct{...}. +type structPointerSlice struct { + v reflect.Value +} + +func (p structPointerSlice) Len() int { return p.v.Len() } +func (p structPointerSlice) Index(i int) structPointer { return structPointer{p.v.Index(i)} } +func (p structPointerSlice) Append(q structPointer) { + p.v.Set(reflect.Append(p.v, q.v)) +} + +var ( + int32Type = reflect.TypeOf(int32(0)) + uint32Type = reflect.TypeOf(uint32(0)) + float32Type = reflect.TypeOf(float32(0)) + int64Type = reflect.TypeOf(int64(0)) + uint64Type = reflect.TypeOf(uint64(0)) + float64Type = reflect.TypeOf(float64(0)) +) + +// A word32 represents a field of type *int32, *uint32, *float32, or *enum. +// That is, v.Type() is *int32, *uint32, *float32, or *enum and v is assignable. +type word32 struct { + v reflect.Value +} + +// IsNil reports whether p is nil. +func word32_IsNil(p word32) bool { + return p.v.IsNil() +} + +// Set sets p to point at a newly allocated word with bits set to x. 
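+// New words are drawn from the Buffer's per-type pools (int32s, uint32s,
+// float32s), so a decode performs one allocation per uint32PoolSize values
+// rather than one per field.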
+func word32_Set(p word32, o *Buffer, x uint32) { + t := p.v.Type().Elem() + switch t { + case int32Type: + if len(o.int32s) == 0 { + o.int32s = make([]int32, uint32PoolSize) + } + o.int32s[0] = int32(x) + p.v.Set(reflect.ValueOf(&o.int32s[0])) + o.int32s = o.int32s[1:] + return + case uint32Type: + if len(o.uint32s) == 0 { + o.uint32s = make([]uint32, uint32PoolSize) + } + o.uint32s[0] = x + p.v.Set(reflect.ValueOf(&o.uint32s[0])) + o.uint32s = o.uint32s[1:] + return + case float32Type: + if len(o.float32s) == 0 { + o.float32s = make([]float32, uint32PoolSize) + } + o.float32s[0] = math.Float32frombits(x) + p.v.Set(reflect.ValueOf(&o.float32s[0])) + o.float32s = o.float32s[1:] + return + } + + // must be enum + p.v.Set(reflect.New(t)) + p.v.Elem().SetInt(int64(int32(x))) +} + +// Get gets the bits pointed at by p, as a uint32. +func word32_Get(p word32) uint32 { + elem := p.v.Elem() + switch elem.Kind() { + case reflect.Int32: + return uint32(elem.Int()) + case reflect.Uint32: + return uint32(elem.Uint()) + case reflect.Float32: + return math.Float32bits(float32(elem.Float())) + } + panic("unreachable") +} + +// Word32 returns a reference to a *int32, *uint32, *float32, or *enum field in the struct. +func structPointer_Word32(p structPointer, f field) word32 { + return word32{structPointer_field(p, f)} +} + +// A word32Slice is a slice of 32-bit values. +// That is, v.Type() is []int32, []uint32, []float32, or []enum. +type word32Slice struct { + v reflect.Value +} + +func (p word32Slice) Append(x uint32) { + n, m := p.v.Len(), p.v.Cap() + if n < m { + p.v.SetLen(n + 1) + } else { + t := p.v.Type().Elem() + p.v.Set(reflect.Append(p.v, reflect.Zero(t))) + } + elem := p.v.Index(n) + switch elem.Kind() { + case reflect.Int32: + elem.SetInt(int64(int32(x))) + case reflect.Uint32: + elem.SetUint(uint64(x)) + case reflect.Float32: + elem.SetFloat(float64(math.Float32frombits(x))) + } +} + +func (p word32Slice) Len() int { + return p.v.Len() +} + +func (p word32Slice) Index(i int) uint32 { + elem := p.v.Index(i) + switch elem.Kind() { + case reflect.Int32: + return uint32(elem.Int()) + case reflect.Uint32: + return uint32(elem.Uint()) + case reflect.Float32: + return math.Float32bits(float32(elem.Float())) + } + panic("unreachable") +} + +// Word32Slice returns a reference to a []int32, []uint32, []float32, or []enum field in the struct. +func structPointer_Word32Slice(p structPointer, f field) word32Slice { + return word32Slice{structPointer_field(p, f)} +} + +// word64 is like word32 but for 64-bit values. 
+type word64 struct { + v reflect.Value +} + +func word64_Set(p word64, o *Buffer, x uint64) { + t := p.v.Type().Elem() + switch t { + case int64Type: + if len(o.int64s) == 0 { + o.int64s = make([]int64, uint64PoolSize) + } + o.int64s[0] = int64(x) + p.v.Set(reflect.ValueOf(&o.int64s[0])) + o.int64s = o.int64s[1:] + return + case uint64Type: + if len(o.uint64s) == 0 { + o.uint64s = make([]uint64, uint64PoolSize) + } + o.uint64s[0] = x + p.v.Set(reflect.ValueOf(&o.uint64s[0])) + o.uint64s = o.uint64s[1:] + return + case float64Type: + if len(o.float64s) == 0 { + o.float64s = make([]float64, uint64PoolSize) + } + o.float64s[0] = math.Float64frombits(x) + p.v.Set(reflect.ValueOf(&o.float64s[0])) + o.float64s = o.float64s[1:] + return + } + panic("unreachable") +} + +func word64_IsNil(p word64) bool { + return p.v.IsNil() +} + +func word64_Get(p word64) uint64 { + elem := p.v.Elem() + switch elem.Kind() { + case reflect.Int64: + return uint64(elem.Int()) + case reflect.Uint64: + return elem.Uint() + case reflect.Float64: + return math.Float64bits(elem.Float()) + } + panic("unreachable") +} + +func structPointer_Word64(p structPointer, f field) word64 { + return word64{structPointer_field(p, f)} +} + +type word64Slice struct { + v reflect.Value +} + +func (p word64Slice) Append(x uint64) { + n, m := p.v.Len(), p.v.Cap() + if n < m { + p.v.SetLen(n + 1) + } else { + t := p.v.Type().Elem() + p.v.Set(reflect.Append(p.v, reflect.Zero(t))) + } + elem := p.v.Index(n) + switch elem.Kind() { + case reflect.Int64: + elem.SetInt(int64(int64(x))) + case reflect.Uint64: + elem.SetUint(uint64(x)) + case reflect.Float64: + elem.SetFloat(float64(math.Float64frombits(x))) + } +} + +func (p word64Slice) Len() int { + return p.v.Len() +} + +func (p word64Slice) Index(i int) uint64 { + elem := p.v.Index(i) + switch elem.Kind() { + case reflect.Int64: + return uint64(elem.Int()) + case reflect.Uint64: + return uint64(elem.Uint()) + case reflect.Float64: + return math.Float64bits(float64(elem.Float())) + } + panic("unreachable") +} + +func structPointer_Word64Slice(p structPointer, f field) word64Slice { + return word64Slice{structPointer_field(p, f)} +} diff --git a/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/pointer_unsafe.go b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/pointer_unsafe.go new file mode 100644 index 0000000000..27a536c88d --- /dev/null +++ b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/pointer_unsafe.go @@ -0,0 +1,218 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2012 The Go Authors. All rights reserved. +// http://code.google.com/p/goprotobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// +build !appengine appenginevm + +// This file contains the implementation of the proto field accesses using package unsafe. + +package proto + +import ( + "reflect" + "unsafe" +) + +// NOTE: These type_Foo functions would more idiomatically be methods, +// but Go does not allow methods on pointer types, and we must preserve +// some pointer type for the garbage collector. We use these +// funcs with clunky names as our poor approximation to methods. +// +// An alternative would be +// type structPointer struct { p unsafe.Pointer } +// but that does not registerize as well. + +// A structPointer is a pointer to a struct. +type structPointer unsafe.Pointer + +// toStructPointer returns a structPointer equivalent to the given reflect value. +func toStructPointer(v reflect.Value) structPointer { + return structPointer(unsafe.Pointer(v.Pointer())) +} + +// IsNil reports whether p is nil. +func structPointer_IsNil(p structPointer) bool { + return p == nil +} + +// Interface returns the struct pointer, assumed to have element type t, +// as an interface value. +func structPointer_Interface(p structPointer, t reflect.Type) interface{} { + return reflect.NewAt(t, unsafe.Pointer(p)).Interface() +} + +// A field identifies a field in a struct, accessible from a structPointer. +// In this implementation, a field is identified by its byte offset from the start of the struct. +type field uintptr + +// toField returns a field equivalent to the given reflect field. +func toField(f *reflect.StructField) field { + return field(f.Offset) +} + +// invalidField is an invalid field identifier. +const invalidField = ^field(0) + +// IsValid reports whether the field identifier is valid. +func (f field) IsValid() bool { + return f != ^field(0) +} + +// Bytes returns the address of a []byte field in the struct. +func structPointer_Bytes(p structPointer, f field) *[]byte { + return (*[]byte)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// BytesSlice returns the address of a [][]byte field in the struct. +func structPointer_BytesSlice(p structPointer, f field) *[][]byte { + return (*[][]byte)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// Bool returns the address of a *bool field in the struct. +func structPointer_Bool(p structPointer, f field) **bool { + return (**bool)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// BoolSlice returns the address of a []bool field in the struct. +func structPointer_BoolSlice(p structPointer, f field) *[]bool { + return (*[]bool)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// String returns the address of a *string field in the struct. 
+func structPointer_String(p structPointer, f field) **string { + return (**string)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// StringSlice returns the address of a []string field in the struct. +func structPointer_StringSlice(p structPointer, f field) *[]string { + return (*[]string)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// ExtMap returns the address of an extension map field in the struct. +func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension { + return (*map[int32]Extension)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// SetStructPointer writes a *struct field in the struct. +func structPointer_SetStructPointer(p structPointer, f field, q structPointer) { + *(*structPointer)(unsafe.Pointer(uintptr(p) + uintptr(f))) = q +} + +// GetStructPointer reads a *struct field in the struct. +func structPointer_GetStructPointer(p structPointer, f field) structPointer { + return *(*structPointer)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// StructPointerSlice the address of a []*struct field in the struct. +func structPointer_StructPointerSlice(p structPointer, f field) *structPointerSlice { + return (*structPointerSlice)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// A structPointerSlice represents a slice of pointers to structs (themselves submessages or groups). +type structPointerSlice []structPointer + +func (v *structPointerSlice) Len() int { return len(*v) } +func (v *structPointerSlice) Index(i int) structPointer { return (*v)[i] } +func (v *structPointerSlice) Append(p structPointer) { *v = append(*v, p) } + +// A word32 is the address of a "pointer to 32-bit value" field. +type word32 **uint32 + +// IsNil reports whether *v is nil. +func word32_IsNil(p word32) bool { + return *p == nil +} + +// Set sets *v to point at a newly allocated word set to x. +func word32_Set(p word32, o *Buffer, x uint32) { + if len(o.uint32s) == 0 { + o.uint32s = make([]uint32, uint32PoolSize) + } + o.uint32s[0] = x + *p = &o.uint32s[0] + o.uint32s = o.uint32s[1:] +} + +// Get gets the value pointed at by *v. +func word32_Get(p word32) uint32 { + return **p +} + +// Word32 returns the address of a *int32, *uint32, *float32, or *enum field in the struct. +func structPointer_Word32(p structPointer, f field) word32 { + return word32((**uint32)(unsafe.Pointer(uintptr(p) + uintptr(f)))) +} + +// A word32Slice is a slice of 32-bit values. +type word32Slice []uint32 + +func (v *word32Slice) Append(x uint32) { *v = append(*v, x) } +func (v *word32Slice) Len() int { return len(*v) } +func (v *word32Slice) Index(i int) uint32 { return (*v)[i] } + +// Word32Slice returns the address of a []int32, []uint32, []float32, or []enum field in the struct. +func structPointer_Word32Slice(p structPointer, f field) *word32Slice { + return (*word32Slice)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// word64 is like word32 but for 64-bit values. +type word64 **uint64 + +func word64_Set(p word64, o *Buffer, x uint64) { + if len(o.uint64s) == 0 { + o.uint64s = make([]uint64, uint64PoolSize) + } + o.uint64s[0] = x + *p = &o.uint64s[0] + o.uint64s = o.uint64s[1:] +} + +func word64_IsNil(p word64) bool { + return *p == nil +} + +func word64_Get(p word64) uint64 { + return **p +} + +func structPointer_Word64(p structPointer, f field) word64 { + return word64((**uint64)(unsafe.Pointer(uintptr(p) + uintptr(f)))) +} + +// word64Slice is like word32Slice but for 64-bit values. 
+type word64Slice []uint64 + +func (v *word64Slice) Append(x uint64) { *v = append(*v, x) } +func (v *word64Slice) Len() int { return len(*v) } +func (v *word64Slice) Index(i int) uint64 { return (*v)[i] } + +func structPointer_Word64Slice(p structPointer, f field) *word64Slice { + return (*word64Slice)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} diff --git a/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/properties.go b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/properties.go new file mode 100644 index 0000000000..f2c1e861cd --- /dev/null +++ b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/properties.go @@ -0,0 +1,658 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2010 The Go Authors. All rights reserved. +// http://code.google.com/p/goprotobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto + +/* + * Routines for encoding data into the wire format for protocol buffers. + */ + +import ( + "fmt" + "os" + "reflect" + "sort" + "strconv" + "strings" + "sync" +) + +const debug bool = false + +// Constants that identify the encoding of a value on the wire. +const ( + WireVarint = 0 + WireFixed64 = 1 + WireBytes = 2 + WireStartGroup = 3 + WireEndGroup = 4 + WireFixed32 = 5 +) + +const startSize = 10 // initial slice/string sizes + +// Encoders are defined in encode.go +// An encoder outputs the full representation of a field, including its +// tag and encoder type. +type encoder func(p *Buffer, prop *Properties, base structPointer) error + +// A valueEncoder encodes a single integer in a particular encoding. +type valueEncoder func(o *Buffer, x uint64) error + +// Sizers are defined in encode.go +// A sizer returns the encoded size of a field, including its tag and encoder +// type. +type sizer func(prop *Properties, base structPointer) int + +// A valueSizer returns the encoded size of a single integer in a particular +// encoding. 
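+// For example, sizeVarint(1) is 1 and sizeVarint(300) is 2, since a varint
+// carries seven payload bits per byte.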
+type valueSizer func(x uint64) int + +// Decoders are defined in decode.go +// A decoder creates a value from its wire representation. +// Unrecognized subelements are saved in unrec. +type decoder func(p *Buffer, prop *Properties, base structPointer) error + +// A valueDecoder decodes a single integer in a particular encoding. +type valueDecoder func(o *Buffer) (x uint64, err error) + +// tagMap is an optimization over map[int]int for typical protocol buffer +// use-cases. Encoded protocol buffers are often in tag order with small tag +// numbers. +type tagMap struct { + fastTags []int + slowTags map[int]int +} + +// tagMapFastLimit is the upper bound on the tag number that will be stored in +// the tagMap slice rather than its map. +const tagMapFastLimit = 1024 + +func (p *tagMap) get(t int) (int, bool) { + if t > 0 && t < tagMapFastLimit { + if t >= len(p.fastTags) { + return 0, false + } + fi := p.fastTags[t] + return fi, fi >= 0 + } + fi, ok := p.slowTags[t] + return fi, ok +} + +func (p *tagMap) put(t int, fi int) { + if t > 0 && t < tagMapFastLimit { + for len(p.fastTags) < t+1 { + p.fastTags = append(p.fastTags, -1) + } + p.fastTags[t] = fi + return + } + if p.slowTags == nil { + p.slowTags = make(map[int]int) + } + p.slowTags[t] = fi +} + +// StructProperties represents properties for all the fields of a struct. +// decoderTags and decoderOrigNames should only be used by the decoder. +type StructProperties struct { + Prop []*Properties // properties for each field + reqCount int // required count + decoderTags tagMap // map from proto tag to struct field number + decoderOrigNames map[string]int // map from original name to struct field number + order []int // list of struct field numbers in tag order + unrecField field // field id of the XXX_unrecognized []byte field + extendable bool // is this an extendable proto +} + +// Implement the sorting interface so we can sort the fields in tag order, as recommended by the spec. +// See encode.go, (*Buffer).enc_struct. + +func (sp *StructProperties) Len() int { return len(sp.order) } +func (sp *StructProperties) Less(i, j int) bool { + return sp.Prop[sp.order[i]].Tag < sp.Prop[sp.order[j]].Tag +} +func (sp *StructProperties) Swap(i, j int) { sp.order[i], sp.order[j] = sp.order[j], sp.order[i] } + +// Properties represents the protocol-specific behavior of a single struct field. +type Properties struct { + Name string // name of the field, for error messages + OrigName string // original name before protocol compiler (always set) + Wire string + WireType int + Tag int + Required bool + Optional bool + Repeated bool + Packed bool // relevant for repeated primitives only + Enum string // set for enum types only + + Default string // default value + HasDefault bool // whether an explicit default was provided + def_uint64 uint64 + + enc encoder + valEnc valueEncoder // set for bool and numeric types only + field field + tagcode []byte // encoding of EncodeVarint((Tag<<3)|WireType) + tagbuf [8]byte + stype reflect.Type // set for struct types only + sprop *StructProperties // set for struct types only + isMarshaler bool + isUnmarshaler bool + + size sizer + valSize valueSizer // set for bool and numeric types only + + dec decoder + valDec valueDecoder // set for bool and numeric types only + + // If this is a packable field, this will be the decoder for the packed version of the field. + packedDec decoder +} + +// String formats the properties in the protobuf struct field tag style. 
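+// For the Label field of the Test message shown in lib.go this yields
+// "bytes,1,req,name=label".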
+func (p *Properties) String() string { + s := p.Wire + s += "," + s += strconv.Itoa(p.Tag) + if p.Required { + s += ",req" + } + if p.Optional { + s += ",opt" + } + if p.Repeated { + s += ",rep" + } + if p.Packed { + s += ",packed" + } + if p.OrigName != p.Name { + s += ",name=" + p.OrigName + } + if len(p.Enum) > 0 { + s += ",enum=" + p.Enum + } + if p.HasDefault { + s += ",def=" + p.Default + } + return s +} + +// Parse populates p by parsing a string in the protobuf struct field tag style. +func (p *Properties) Parse(s string) { + // "bytes,49,opt,name=foo,def=hello!" + fields := strings.Split(s, ",") // breaks def=, but handled below. + if len(fields) < 2 { + fmt.Fprintf(os.Stderr, "proto: tag has too few fields: %q\n", s) + return + } + + p.Wire = fields[0] + switch p.Wire { + case "varint": + p.WireType = WireVarint + p.valEnc = (*Buffer).EncodeVarint + p.valDec = (*Buffer).DecodeVarint + p.valSize = sizeVarint + case "fixed32": + p.WireType = WireFixed32 + p.valEnc = (*Buffer).EncodeFixed32 + p.valDec = (*Buffer).DecodeFixed32 + p.valSize = sizeFixed32 + case "fixed64": + p.WireType = WireFixed64 + p.valEnc = (*Buffer).EncodeFixed64 + p.valDec = (*Buffer).DecodeFixed64 + p.valSize = sizeFixed64 + case "zigzag32": + p.WireType = WireVarint + p.valEnc = (*Buffer).EncodeZigzag32 + p.valDec = (*Buffer).DecodeZigzag32 + p.valSize = sizeZigzag32 + case "zigzag64": + p.WireType = WireVarint + p.valEnc = (*Buffer).EncodeZigzag64 + p.valDec = (*Buffer).DecodeZigzag64 + p.valSize = sizeZigzag64 + case "bytes", "group": + p.WireType = WireBytes + // no numeric converter for non-numeric types + default: + fmt.Fprintf(os.Stderr, "proto: tag has unknown wire type: %q\n", s) + return + } + + var err error + p.Tag, err = strconv.Atoi(fields[1]) + if err != nil { + return + } + + for i := 2; i < len(fields); i++ { + f := fields[i] + switch { + case f == "req": + p.Required = true + case f == "opt": + p.Optional = true + case f == "rep": + p.Repeated = true + case f == "packed": + p.Packed = true + case strings.HasPrefix(f, "name="): + p.OrigName = f[5:] + case strings.HasPrefix(f, "enum="): + p.Enum = f[5:] + case strings.HasPrefix(f, "def="): + p.HasDefault = true + p.Default = f[4:] // rest of string + if i+1 < len(fields) { + // Commas aren't escaped, and def is always last. + p.Default += "," + strings.Join(fields[i+1:], ",") + break + } + } + } +} + +func logNoSliceEnc(t1, t2 reflect.Type) { + fmt.Fprintf(os.Stderr, "proto: no slice oenc for %T = []%T\n", t1, t2) +} + +var protoMessageType = reflect.TypeOf((*Message)(nil)).Elem() + +// Initialize the fields for encoding and decoding.
+func (p *Properties) setEncAndDec(typ reflect.Type, lockGetProp bool) { + p.enc = nil + p.dec = nil + p.size = nil + + switch t1 := typ; t1.Kind() { + default: + fmt.Fprintf(os.Stderr, "proto: no coders for %T\n", t1) + + case reflect.Ptr: + switch t2 := t1.Elem(); t2.Kind() { + default: + fmt.Fprintf(os.Stderr, "proto: no encoder function for %T -> %T\n", t1, t2) + break + case reflect.Bool: + p.enc = (*Buffer).enc_bool + p.dec = (*Buffer).dec_bool + p.size = size_bool + case reflect.Int32: + p.enc = (*Buffer).enc_int32 + p.dec = (*Buffer).dec_int32 + p.size = size_int32 + case reflect.Uint32: + p.enc = (*Buffer).enc_uint32 + p.dec = (*Buffer).dec_int32 // can reuse + p.size = size_uint32 + case reflect.Int64, reflect.Uint64: + p.enc = (*Buffer).enc_int64 + p.dec = (*Buffer).dec_int64 + p.size = size_int64 + case reflect.Float32: + p.enc = (*Buffer).enc_uint32 // can just treat them as bits + p.dec = (*Buffer).dec_int32 + p.size = size_uint32 + case reflect.Float64: + p.enc = (*Buffer).enc_int64 // can just treat them as bits + p.dec = (*Buffer).dec_int64 + p.size = size_int64 + case reflect.String: + p.enc = (*Buffer).enc_string + p.dec = (*Buffer).dec_string + p.size = size_string + case reflect.Struct: + p.stype = t1.Elem() + p.isMarshaler = isMarshaler(t1) + p.isUnmarshaler = isUnmarshaler(t1) + if p.Wire == "bytes" { + p.enc = (*Buffer).enc_struct_message + p.dec = (*Buffer).dec_struct_message + p.size = size_struct_message + } else { + p.enc = (*Buffer).enc_struct_group + p.dec = (*Buffer).dec_struct_group + p.size = size_struct_group + } + } + + case reflect.Slice: + switch t2 := t1.Elem(); t2.Kind() { + default: + logNoSliceEnc(t1, t2) + break + case reflect.Bool: + if p.Packed { + p.enc = (*Buffer).enc_slice_packed_bool + p.size = size_slice_packed_bool + } else { + p.enc = (*Buffer).enc_slice_bool + p.size = size_slice_bool + } + p.dec = (*Buffer).dec_slice_bool + p.packedDec = (*Buffer).dec_slice_packed_bool + case reflect.Int32: + if p.Packed { + p.enc = (*Buffer).enc_slice_packed_int32 + p.size = size_slice_packed_int32 + } else { + p.enc = (*Buffer).enc_slice_int32 + p.size = size_slice_int32 + } + p.dec = (*Buffer).dec_slice_int32 + p.packedDec = (*Buffer).dec_slice_packed_int32 + case reflect.Uint32: + if p.Packed { + p.enc = (*Buffer).enc_slice_packed_uint32 + p.size = size_slice_packed_uint32 + } else { + p.enc = (*Buffer).enc_slice_uint32 + p.size = size_slice_uint32 + } + p.dec = (*Buffer).dec_slice_int32 + p.packedDec = (*Buffer).dec_slice_packed_int32 + case reflect.Int64, reflect.Uint64: + if p.Packed { + p.enc = (*Buffer).enc_slice_packed_int64 + p.size = size_slice_packed_int64 + } else { + p.enc = (*Buffer).enc_slice_int64 + p.size = size_slice_int64 + } + p.dec = (*Buffer).dec_slice_int64 + p.packedDec = (*Buffer).dec_slice_packed_int64 + case reflect.Uint8: + p.enc = (*Buffer).enc_slice_byte + p.dec = (*Buffer).dec_slice_byte + p.size = size_slice_byte + case reflect.Float32, reflect.Float64: + switch t2.Bits() { + case 32: + // can just treat them as bits + if p.Packed { + p.enc = (*Buffer).enc_slice_packed_uint32 + p.size = size_slice_packed_uint32 + } else { + p.enc = (*Buffer).enc_slice_uint32 + p.size = size_slice_uint32 + } + p.dec = (*Buffer).dec_slice_int32 + p.packedDec = (*Buffer).dec_slice_packed_int32 + case 64: + // can just treat them as bits + if p.Packed { + p.enc = (*Buffer).enc_slice_packed_int64 + p.size = size_slice_packed_int64 + } else { + p.enc = (*Buffer).enc_slice_int64 + p.size = size_slice_int64 + } + p.dec = 
(*Buffer).dec_slice_int64 + p.packedDec = (*Buffer).dec_slice_packed_int64 + default: + logNoSliceEnc(t1, t2) + break + } + case reflect.String: + p.enc = (*Buffer).enc_slice_string + p.dec = (*Buffer).dec_slice_string + p.size = size_slice_string + case reflect.Ptr: + switch t3 := t2.Elem(); t3.Kind() { + default: + fmt.Fprintf(os.Stderr, "proto: no ptr oenc for %T -> %T -> %T\n", t1, t2, t3) + break + case reflect.Struct: + p.stype = t2.Elem() + p.isMarshaler = isMarshaler(t2) + p.isUnmarshaler = isUnmarshaler(t2) + if p.Wire == "bytes" { + p.enc = (*Buffer).enc_slice_struct_message + p.dec = (*Buffer).dec_slice_struct_message + p.size = size_slice_struct_message + } else { + p.enc = (*Buffer).enc_slice_struct_group + p.dec = (*Buffer).dec_slice_struct_group + p.size = size_slice_struct_group + } + } + case reflect.Slice: + switch t2.Elem().Kind() { + default: + fmt.Fprintf(os.Stderr, "proto: no slice elem oenc for %T -> %T -> %T\n", t1, t2, t2.Elem()) + break + case reflect.Uint8: + p.enc = (*Buffer).enc_slice_slice_byte + p.dec = (*Buffer).dec_slice_slice_byte + p.size = size_slice_slice_byte + } + } + } + + // precalculate tag code + wire := p.WireType + if p.Packed { + wire = WireBytes + } + x := uint32(p.Tag)<<3 | uint32(wire) + i := 0 + for i = 0; x > 127; i++ { + p.tagbuf[i] = 0x80 | uint8(x&0x7F) + x >>= 7 + } + p.tagbuf[i] = uint8(x) + p.tagcode = p.tagbuf[0 : i+1] + + if p.stype != nil { + if lockGetProp { + p.sprop = GetProperties(p.stype) + } else { + p.sprop = getPropertiesLocked(p.stype) + } + } +} + +var ( + marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem() + unmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem() +) + +// isMarshaler reports whether type t implements Marshaler. +func isMarshaler(t reflect.Type) bool { + // We're checking for (likely) pointer-receiver methods + // so if t is not a pointer, something is very wrong. + // The calls above only invoke isMarshaler on pointer types. + if t.Kind() != reflect.Ptr { + panic("proto: misuse of isMarshaler") + } + return t.Implements(marshalerType) +} + +// isUnmarshaler reports whether type t implements Unmarshaler. +func isUnmarshaler(t reflect.Type) bool { + // We're checking for (likely) pointer-receiver methods + // so if t is not a pointer, something is very wrong. + // The calls above only invoke isUnmarshaler on pointer types. + if t.Kind() != reflect.Ptr { + panic("proto: misuse of isUnmarshaler") + } + return t.Implements(unmarshalerType) +} + +// Init populates the properties from a protocol buffer struct tag. +func (p *Properties) Init(typ reflect.Type, name, tag string, f *reflect.StructField) { + p.init(typ, name, tag, f, true) +} + +func (p *Properties) init(typ reflect.Type, name, tag string, f *reflect.StructField, lockGetProp bool) { + // "bytes,49,opt,def=hello!" + p.Name = name + p.OrigName = name + if f != nil { + p.field = toField(f) + } + if tag == "" { + return + } + p.Parse(tag) + p.setEncAndDec(typ, lockGetProp) +} + +var ( + mutex sync.Mutex + propertiesMap = make(map[reflect.Type]*StructProperties) +) + +// GetProperties returns the list of properties for the type represented by t. +func GetProperties(t reflect.Type) *StructProperties { + mutex.Lock() + sprop := getPropertiesLocked(t) + mutex.Unlock() + return sprop +} + +// getPropertiesLocked requires that mutex is held. 
+func getPropertiesLocked(t reflect.Type) *StructProperties { + if prop, ok := propertiesMap[t]; ok { + if collectStats { + stats.Chit++ + } + return prop + } + if collectStats { + stats.Cmiss++ + } + + prop := new(StructProperties) + // in case of recursive protos, fill this in now. + propertiesMap[t] = prop + + // build properties + prop.extendable = reflect.PtrTo(t).Implements(extendableProtoType) + prop.unrecField = invalidField + prop.Prop = make([]*Properties, t.NumField()) + prop.order = make([]int, t.NumField()) + + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + p := new(Properties) + name := f.Name + p.init(f.Type, name, f.Tag.Get("protobuf"), &f, false) + + if f.Name == "XXX_extensions" { // special case + p.enc = (*Buffer).enc_map + p.dec = nil // not needed + p.size = size_map + } + if f.Name == "XXX_unrecognized" { // special case + prop.unrecField = toField(&f) + } + prop.Prop[i] = p + prop.order[i] = i + if debug { + print(i, " ", f.Name, " ", t.String(), " ") + if p.Tag > 0 { + print(p.String()) + } + print("\n") + } + if p.enc == nil && !strings.HasPrefix(f.Name, "XXX_") { + fmt.Fprintln(os.Stderr, "proto: no encoder for", f.Name, f.Type.String(), "[GetProperties]") + } + } + + // Re-order prop.order. + sort.Sort(prop) + + // build required counts + // build tags + reqCount := 0 + prop.decoderOrigNames = make(map[string]int) + for i, p := range prop.Prop { + if strings.HasPrefix(p.Name, "XXX_") { + // Internal fields should not appear in tags/origNames maps. + // They are handled specially when encoding and decoding. + continue + } + if p.Required { + reqCount++ + } + prop.decoderTags.put(p.Tag, i) + prop.decoderOrigNames[p.OrigName] = i + } + prop.reqCount = reqCount + + return prop +} + +// Return the Properties object for the x[0]'th field of the structure. +func propByIndex(t reflect.Type, x []int) *Properties { + if len(x) != 1 { + fmt.Fprintf(os.Stderr, "proto: field index dimension %d (not 1) for type %s\n", len(x), t) + return nil + } + prop := GetProperties(t) + return prop.Prop[x[0]] +} + +// Get the address and type of a pointer to a struct from an interface. +func getbase(pb Message) (t reflect.Type, b structPointer, err error) { + if pb == nil { + err = ErrNil + return + } + // get the reflect type of the pointer to the struct. + t = reflect.TypeOf(pb) + // get the address of the struct. + value := reflect.ValueOf(pb) + b = toStructPointer(value) + return +} + +// A global registry of enum types. +// The generated code will register the generated maps by calling RegisterEnum. + +var enumValueMaps = make(map[string]map[string]int32) + +// RegisterEnum is called from the generated code to install the enum descriptor +// maps into the global table to aid parsing text format protocol buffers. +func RegisterEnum(typeName string, unusedNameMap map[int32]string, valueMap map[string]int32) { + if _, ok := enumValueMaps[typeName]; ok { + panic("proto: duplicate enum registered: " + typeName) + } + enumValueMaps[typeName] = valueMap +} diff --git a/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/size2_test.go b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/size2_test.go new file mode 100644 index 0000000000..55902a4a9c --- /dev/null +++ b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/size2_test.go @@ -0,0 +1,63 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2012 The Go Authors. All rights reserved. 
+// http://code.google.com/p/goprotobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto + +import ( + "testing" +) + +// This is a separate file and package from size_test.go because that one uses +// generated messages and thus may not be in package proto without having a circular +// dependency, whereas this file tests unexported details of size.go. + +func TestVarintSize(t *testing.T) { + // Check the edge cases carefully. + testCases := []struct { + n uint64 + size int + }{ + {0, 1}, + {1, 1}, + {127, 1}, + {128, 2}, + {16383, 2}, + {16384, 3}, + {1<<63 - 1, 9}, + {1 << 63, 10}, + } + for _, tc := range testCases { + size := sizeVarint(tc.n) + if size != tc.size { + t.Errorf("sizeVarint(%d) = %d, want %d", tc.n, size, tc.size) + } + } +} diff --git a/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/size_test.go b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/size_test.go new file mode 100644 index 0000000000..6cb9de6d1d --- /dev/null +++ b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/size_test.go @@ -0,0 +1,120 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2012 The Go Authors. All rights reserved. +// http://code.google.com/p/goprotobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto_test + +import ( + "log" + "testing" + + pb "./testdata" + . "code.google.com/p/goprotobuf/proto" +) + +var messageWithExtension1 = &pb.MyMessage{Count: Int32(7)} + +// messageWithExtension2 is in equal_test.go. +var messageWithExtension3 = &pb.MyMessage{Count: Int32(8)} + +func init() { + if err := SetExtension(messageWithExtension1, pb.E_Ext_More, &pb.Ext{Data: String("Abbott")}); err != nil { + log.Panicf("SetExtension: %v", err) + } + if err := SetExtension(messageWithExtension3, pb.E_Ext_More, &pb.Ext{Data: String("Costello")}); err != nil { + log.Panicf("SetExtension: %v", err) + } + + // Force messageWithExtension3 to have the extension encoded. + Marshal(messageWithExtension3) + +} + +var SizeTests = []struct { + desc string + pb Message +}{ + {"empty", &pb.OtherMessage{}}, + // Basic types. + {"bool", &pb.Defaults{F_Bool: Bool(true)}}, + {"int32", &pb.Defaults{F_Int32: Int32(12)}}, + {"negative int32", &pb.Defaults{F_Int32: Int32(-1)}}, + {"small int64", &pb.Defaults{F_Int64: Int64(1)}}, + {"big int64", &pb.Defaults{F_Int64: Int64(1 << 20)}}, + {"negative int64", &pb.Defaults{F_Int64: Int64(-1)}}, + {"fixed32", &pb.Defaults{F_Fixed32: Uint32(71)}}, + {"fixed64", &pb.Defaults{F_Fixed64: Uint64(72)}}, + {"uint32", &pb.Defaults{F_Uint32: Uint32(123)}}, + {"uint64", &pb.Defaults{F_Uint64: Uint64(124)}}, + {"float", &pb.Defaults{F_Float: Float32(12.6)}}, + {"double", &pb.Defaults{F_Double: Float64(13.9)}}, + {"string", &pb.Defaults{F_String: String("niles")}}, + {"bytes", &pb.Defaults{F_Bytes: []byte("wowsa")}}, + {"bytes, empty", &pb.Defaults{F_Bytes: []byte{}}}, + {"sint32", &pb.Defaults{F_Sint32: Int32(65)}}, + {"sint64", &pb.Defaults{F_Sint64: Int64(67)}}, + {"enum", &pb.Defaults{F_Enum: pb.Defaults_BLUE.Enum()}}, + // Repeated. + {"empty repeated bool", &pb.MoreRepeated{Bools: []bool{}}}, + {"repeated bool", &pb.MoreRepeated{Bools: []bool{false, true, true, false}}}, + {"packed repeated bool", &pb.MoreRepeated{BoolsPacked: []bool{false, true, true, false, true, true, true}}}, + {"repeated int32", &pb.MoreRepeated{Ints: []int32{1, 12203, 1729, -1}}}, + {"repeated int32 packed", &pb.MoreRepeated{IntsPacked: []int32{1, 12203, 1729}}}, + {"repeated int64 packed", &pb.MoreRepeated{Int64SPacked: []int64{ + // Need enough large numbers to verify that the header is counting the number of bytes + // for the field, not the number of elements. 
+ 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, + 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, + }}}, + {"repeated string", &pb.MoreRepeated{Strings: []string{"r", "ken", "gri"}}}, + {"repeated fixed", &pb.MoreRepeated{Fixeds: []uint32{1, 2, 3, 4}}}, + // Nested. + {"nested", &pb.OldMessage{Nested: &pb.OldMessage_Nested{Name: String("whatever")}}}, + {"group", &pb.GroupOld{G: &pb.GroupOld_G{X: Int32(12345)}}}, + // Other things. + {"unrecognized", &pb.MoreRepeated{XXX_unrecognized: []byte{13<<3 | 0, 4}}}, + {"extension (unencoded)", messageWithExtension1}, + {"extension (encoded)", messageWithExtension3}, +} + +func TestSize(t *testing.T) { + for _, tc := range SizeTests { + size := Size(tc.pb) + b, err := Marshal(tc.pb) + if err != nil { + t.Errorf("%v: Marshal failed: %v", tc.desc, err) + continue + } + if size != len(b) { + t.Errorf("%v: Size(%v) = %d, want %d", tc.desc, tc.pb, size, len(b)) + t.Logf("%v: bytes: %#v", tc.desc, b) + } + } +} diff --git a/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/testdata/Makefile b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/testdata/Makefile new file mode 100644 index 0000000000..9fa10e4c69 --- /dev/null +++ b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/testdata/Makefile @@ -0,0 +1,50 @@ +# Go support for Protocol Buffers - Google's data interchange format +# +# Copyright 2010 The Go Authors. All rights reserved. +# http://code.google.com/p/goprotobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +include ../../Make.protobuf + +all: regenerate + +regenerate: + rm -f test.pb.go + make test.pb.go + +# The following rules are just aids to development. Not needed for typical testing. 
+ +diff: regenerate + hg diff test.pb.go + +restore: + cp test.pb.go.golden test.pb.go + +preserve: + cp test.pb.go test.pb.go.golden diff --git a/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/testdata/golden_test.go b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/testdata/golden_test.go new file mode 100644 index 0000000000..f614aa1b73 --- /dev/null +++ b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/testdata/golden_test.go @@ -0,0 +1,86 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2012 The Go Authors. All rights reserved. +// http://code.google.com/p/goprotobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Verify that the compiler output for test.proto is unchanged. + +package testdata + +import ( + "crypto/sha1" + "fmt" + "io/ioutil" + "os" + "os/exec" + "path/filepath" + "testing" +) + +// sum returns in string form (for easy comparison) the SHA-1 hash of the named file. +func sum(t *testing.T, name string) string { + data, err := ioutil.ReadFile(name) + if err != nil { + t.Fatal(err) + } + t.Logf("sum(%q): length is %d", name, len(data)) + hash := sha1.New() + _, err = hash.Write(data) + if err != nil { + t.Fatal(err) + } + return fmt.Sprintf("% x", hash.Sum(nil)) +} + +func run(t *testing.T, name string, args ...string) { + cmd := exec.Command(name, args...) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + err := cmd.Run() + if err != nil { + t.Fatal(err) + } +} + +func TestGolden(t *testing.T) { + // Compute the original checksum. + goldenSum := sum(t, "test.pb.go") + // Run the proto compiler. + run(t, "protoc", "--go_out="+os.TempDir(), "test.proto") + newFile := filepath.Join(os.TempDir(), "test.pb.go") + defer os.Remove(newFile) + // Compute the new checksum. 
+ newSum := sum(t, newFile) + // Verify + if newSum != goldenSum { + run(t, "diff", "-u", "test.pb.go", newFile) + t.Fatal("Code generated by protoc-gen-go has changed; update test.pb.go") + } +} diff --git a/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/testdata/test.pb.go b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/testdata/test.pb.go new file mode 100644 index 0000000000..5a9a603662 --- /dev/null +++ b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/testdata/test.pb.go @@ -0,0 +1,2350 @@ +// Code generated by protoc-gen-go. +// source: test.proto +// DO NOT EDIT! + +/* +Package testdata is a generated protocol buffer package. + +It is generated from these files: + test.proto + +It has these top-level messages: + GoEnum + GoTestField + GoTest + GoSkipTest + NonPackedTest + PackedTest + MaxTag + OldMessage + NewMessage + InnerMessage + OtherMessage + MyMessage + Ext + MyMessageSet + Empty + MessageList + Strings + Defaults + SubDefaults + RepeatedEnum + MoreRepeated + GroupOld + GroupNew + FloatingPoint +*/ +package testdata + +import proto "code.google.com/p/goprotobuf/proto" +import math "math" + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = math.Inf + +type FOO int32 + +const ( + FOO_FOO1 FOO = 1 +) + +var FOO_name = map[int32]string{ + 1: "FOO1", +} +var FOO_value = map[string]int32{ + "FOO1": 1, +} + +func (x FOO) Enum() *FOO { + p := new(FOO) + *p = x + return p +} +func (x FOO) String() string { + return proto.EnumName(FOO_name, int32(x)) +} +func (x *FOO) UnmarshalJSON(data []byte) error { + value, err := proto.UnmarshalJSONEnum(FOO_value, data, "FOO") + if err != nil { + return err + } + *x = FOO(value) + return nil +} + +// An enum, for completeness. 
+type GoTest_KIND int32 + +const ( + GoTest_VOID GoTest_KIND = 0 + // Basic types + GoTest_BOOL GoTest_KIND = 1 + GoTest_BYTES GoTest_KIND = 2 + GoTest_FINGERPRINT GoTest_KIND = 3 + GoTest_FLOAT GoTest_KIND = 4 + GoTest_INT GoTest_KIND = 5 + GoTest_STRING GoTest_KIND = 6 + GoTest_TIME GoTest_KIND = 7 + // Groupings + GoTest_TUPLE GoTest_KIND = 8 + GoTest_ARRAY GoTest_KIND = 9 + GoTest_MAP GoTest_KIND = 10 + // Table types + GoTest_TABLE GoTest_KIND = 11 + // Functions + GoTest_FUNCTION GoTest_KIND = 12 +) + +var GoTest_KIND_name = map[int32]string{ + 0: "VOID", + 1: "BOOL", + 2: "BYTES", + 3: "FINGERPRINT", + 4: "FLOAT", + 5: "INT", + 6: "STRING", + 7: "TIME", + 8: "TUPLE", + 9: "ARRAY", + 10: "MAP", + 11: "TABLE", + 12: "FUNCTION", +} +var GoTest_KIND_value = map[string]int32{ + "VOID": 0, + "BOOL": 1, + "BYTES": 2, + "FINGERPRINT": 3, + "FLOAT": 4, + "INT": 5, + "STRING": 6, + "TIME": 7, + "TUPLE": 8, + "ARRAY": 9, + "MAP": 10, + "TABLE": 11, + "FUNCTION": 12, +} + +func (x GoTest_KIND) Enum() *GoTest_KIND { + p := new(GoTest_KIND) + *p = x + return p +} +func (x GoTest_KIND) String() string { + return proto.EnumName(GoTest_KIND_name, int32(x)) +} +func (x *GoTest_KIND) UnmarshalJSON(data []byte) error { + value, err := proto.UnmarshalJSONEnum(GoTest_KIND_value, data, "GoTest_KIND") + if err != nil { + return err + } + *x = GoTest_KIND(value) + return nil +} + +type MyMessage_Color int32 + +const ( + MyMessage_RED MyMessage_Color = 0 + MyMessage_GREEN MyMessage_Color = 1 + MyMessage_BLUE MyMessage_Color = 2 +) + +var MyMessage_Color_name = map[int32]string{ + 0: "RED", + 1: "GREEN", + 2: "BLUE", +} +var MyMessage_Color_value = map[string]int32{ + "RED": 0, + "GREEN": 1, + "BLUE": 2, +} + +func (x MyMessage_Color) Enum() *MyMessage_Color { + p := new(MyMessage_Color) + *p = x + return p +} +func (x MyMessage_Color) String() string { + return proto.EnumName(MyMessage_Color_name, int32(x)) +} +func (x *MyMessage_Color) UnmarshalJSON(data []byte) error { + value, err := proto.UnmarshalJSONEnum(MyMessage_Color_value, data, "MyMessage_Color") + if err != nil { + return err + } + *x = MyMessage_Color(value) + return nil +} + +type Defaults_Color int32 + +const ( + Defaults_RED Defaults_Color = 0 + Defaults_GREEN Defaults_Color = 1 + Defaults_BLUE Defaults_Color = 2 +) + +var Defaults_Color_name = map[int32]string{ + 0: "RED", + 1: "GREEN", + 2: "BLUE", +} +var Defaults_Color_value = map[string]int32{ + "RED": 0, + "GREEN": 1, + "BLUE": 2, +} + +func (x Defaults_Color) Enum() *Defaults_Color { + p := new(Defaults_Color) + *p = x + return p +} +func (x Defaults_Color) String() string { + return proto.EnumName(Defaults_Color_name, int32(x)) +} +func (x *Defaults_Color) UnmarshalJSON(data []byte) error { + value, err := proto.UnmarshalJSONEnum(Defaults_Color_value, data, "Defaults_Color") + if err != nil { + return err + } + *x = Defaults_Color(value) + return nil +} + +type RepeatedEnum_Color int32 + +const ( + RepeatedEnum_RED RepeatedEnum_Color = 1 +) + +var RepeatedEnum_Color_name = map[int32]string{ + 1: "RED", +} +var RepeatedEnum_Color_value = map[string]int32{ + "RED": 1, +} + +func (x RepeatedEnum_Color) Enum() *RepeatedEnum_Color { + p := new(RepeatedEnum_Color) + *p = x + return p +} +func (x RepeatedEnum_Color) String() string { + return proto.EnumName(RepeatedEnum_Color_name, int32(x)) +} +func (x *RepeatedEnum_Color) UnmarshalJSON(data []byte) error { + value, err := proto.UnmarshalJSONEnum(RepeatedEnum_Color_value, data, "RepeatedEnum_Color") + if err != nil { + return err + } + *x = 
RepeatedEnum_Color(value) + return nil +} + +type GoEnum struct { + Foo *FOO `protobuf:"varint,1,req,name=foo,enum=testdata.FOO" json:"foo,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *GoEnum) Reset() { *m = GoEnum{} } +func (m *GoEnum) String() string { return proto.CompactTextString(m) } +func (*GoEnum) ProtoMessage() {} + +func (m *GoEnum) GetFoo() FOO { + if m != nil && m.Foo != nil { + return *m.Foo + } + return FOO_FOO1 +} + +type GoTestField struct { + Label *string `protobuf:"bytes,1,req" json:"Label,omitempty"` + Type *string `protobuf:"bytes,2,req" json:"Type,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *GoTestField) Reset() { *m = GoTestField{} } +func (m *GoTestField) String() string { return proto.CompactTextString(m) } +func (*GoTestField) ProtoMessage() {} + +func (m *GoTestField) GetLabel() string { + if m != nil && m.Label != nil { + return *m.Label + } + return "" +} + +func (m *GoTestField) GetType() string { + if m != nil && m.Type != nil { + return *m.Type + } + return "" +} + +type GoTest struct { + // Some typical parameters + Kind *GoTest_KIND `protobuf:"varint,1,req,enum=testdata.GoTest_KIND" json:"Kind,omitempty"` + Table *string `protobuf:"bytes,2,opt" json:"Table,omitempty"` + Param *int32 `protobuf:"varint,3,opt" json:"Param,omitempty"` + // Required, repeated and optional foreign fields. + RequiredField *GoTestField `protobuf:"bytes,4,req" json:"RequiredField,omitempty"` + RepeatedField []*GoTestField `protobuf:"bytes,5,rep" json:"RepeatedField,omitempty"` + OptionalField *GoTestField `protobuf:"bytes,6,opt" json:"OptionalField,omitempty"` + // Required fields of all basic types + F_BoolRequired *bool `protobuf:"varint,10,req,name=F_Bool_required" json:"F_Bool_required,omitempty"` + F_Int32Required *int32 `protobuf:"varint,11,req,name=F_Int32_required" json:"F_Int32_required,omitempty"` + F_Int64Required *int64 `protobuf:"varint,12,req,name=F_Int64_required" json:"F_Int64_required,omitempty"` + F_Fixed32Required *uint32 `protobuf:"fixed32,13,req,name=F_Fixed32_required" json:"F_Fixed32_required,omitempty"` + F_Fixed64Required *uint64 `protobuf:"fixed64,14,req,name=F_Fixed64_required" json:"F_Fixed64_required,omitempty"` + F_Uint32Required *uint32 `protobuf:"varint,15,req,name=F_Uint32_required" json:"F_Uint32_required,omitempty"` + F_Uint64Required *uint64 `protobuf:"varint,16,req,name=F_Uint64_required" json:"F_Uint64_required,omitempty"` + F_FloatRequired *float32 `protobuf:"fixed32,17,req,name=F_Float_required" json:"F_Float_required,omitempty"` + F_DoubleRequired *float64 `protobuf:"fixed64,18,req,name=F_Double_required" json:"F_Double_required,omitempty"` + F_StringRequired *string `protobuf:"bytes,19,req,name=F_String_required" json:"F_String_required,omitempty"` + F_BytesRequired []byte `protobuf:"bytes,101,req,name=F_Bytes_required" json:"F_Bytes_required,omitempty"` + F_Sint32Required *int32 `protobuf:"zigzag32,102,req,name=F_Sint32_required" json:"F_Sint32_required,omitempty"` + F_Sint64Required *int64 `protobuf:"zigzag64,103,req,name=F_Sint64_required" json:"F_Sint64_required,omitempty"` + // Repeated fields of all basic types + F_BoolRepeated []bool `protobuf:"varint,20,rep,name=F_Bool_repeated" json:"F_Bool_repeated,omitempty"` + F_Int32Repeated []int32 `protobuf:"varint,21,rep,name=F_Int32_repeated" json:"F_Int32_repeated,omitempty"` + F_Int64Repeated []int64 `protobuf:"varint,22,rep,name=F_Int64_repeated" json:"F_Int64_repeated,omitempty"` + F_Fixed32Repeated []uint32 
`protobuf:"fixed32,23,rep,name=F_Fixed32_repeated" json:"F_Fixed32_repeated,omitempty"` + F_Fixed64Repeated []uint64 `protobuf:"fixed64,24,rep,name=F_Fixed64_repeated" json:"F_Fixed64_repeated,omitempty"` + F_Uint32Repeated []uint32 `protobuf:"varint,25,rep,name=F_Uint32_repeated" json:"F_Uint32_repeated,omitempty"` + F_Uint64Repeated []uint64 `protobuf:"varint,26,rep,name=F_Uint64_repeated" json:"F_Uint64_repeated,omitempty"` + F_FloatRepeated []float32 `protobuf:"fixed32,27,rep,name=F_Float_repeated" json:"F_Float_repeated,omitempty"` + F_DoubleRepeated []float64 `protobuf:"fixed64,28,rep,name=F_Double_repeated" json:"F_Double_repeated,omitempty"` + F_StringRepeated []string `protobuf:"bytes,29,rep,name=F_String_repeated" json:"F_String_repeated,omitempty"` + F_BytesRepeated [][]byte `protobuf:"bytes,201,rep,name=F_Bytes_repeated" json:"F_Bytes_repeated,omitempty"` + F_Sint32Repeated []int32 `protobuf:"zigzag32,202,rep,name=F_Sint32_repeated" json:"F_Sint32_repeated,omitempty"` + F_Sint64Repeated []int64 `protobuf:"zigzag64,203,rep,name=F_Sint64_repeated" json:"F_Sint64_repeated,omitempty"` + // Optional fields of all basic types + F_BoolOptional *bool `protobuf:"varint,30,opt,name=F_Bool_optional" json:"F_Bool_optional,omitempty"` + F_Int32Optional *int32 `protobuf:"varint,31,opt,name=F_Int32_optional" json:"F_Int32_optional,omitempty"` + F_Int64Optional *int64 `protobuf:"varint,32,opt,name=F_Int64_optional" json:"F_Int64_optional,omitempty"` + F_Fixed32Optional *uint32 `protobuf:"fixed32,33,opt,name=F_Fixed32_optional" json:"F_Fixed32_optional,omitempty"` + F_Fixed64Optional *uint64 `protobuf:"fixed64,34,opt,name=F_Fixed64_optional" json:"F_Fixed64_optional,omitempty"` + F_Uint32Optional *uint32 `protobuf:"varint,35,opt,name=F_Uint32_optional" json:"F_Uint32_optional,omitempty"` + F_Uint64Optional *uint64 `protobuf:"varint,36,opt,name=F_Uint64_optional" json:"F_Uint64_optional,omitempty"` + F_FloatOptional *float32 `protobuf:"fixed32,37,opt,name=F_Float_optional" json:"F_Float_optional,omitempty"` + F_DoubleOptional *float64 `protobuf:"fixed64,38,opt,name=F_Double_optional" json:"F_Double_optional,omitempty"` + F_StringOptional *string `protobuf:"bytes,39,opt,name=F_String_optional" json:"F_String_optional,omitempty"` + F_BytesOptional []byte `protobuf:"bytes,301,opt,name=F_Bytes_optional" json:"F_Bytes_optional,omitempty"` + F_Sint32Optional *int32 `protobuf:"zigzag32,302,opt,name=F_Sint32_optional" json:"F_Sint32_optional,omitempty"` + F_Sint64Optional *int64 `protobuf:"zigzag64,303,opt,name=F_Sint64_optional" json:"F_Sint64_optional,omitempty"` + // Default-valued fields of all basic types + F_BoolDefaulted *bool `protobuf:"varint,40,opt,name=F_Bool_defaulted,def=1" json:"F_Bool_defaulted,omitempty"` + F_Int32Defaulted *int32 `protobuf:"varint,41,opt,name=F_Int32_defaulted,def=32" json:"F_Int32_defaulted,omitempty"` + F_Int64Defaulted *int64 `protobuf:"varint,42,opt,name=F_Int64_defaulted,def=64" json:"F_Int64_defaulted,omitempty"` + F_Fixed32Defaulted *uint32 `protobuf:"fixed32,43,opt,name=F_Fixed32_defaulted,def=320" json:"F_Fixed32_defaulted,omitempty"` + F_Fixed64Defaulted *uint64 `protobuf:"fixed64,44,opt,name=F_Fixed64_defaulted,def=640" json:"F_Fixed64_defaulted,omitempty"` + F_Uint32Defaulted *uint32 `protobuf:"varint,45,opt,name=F_Uint32_defaulted,def=3200" json:"F_Uint32_defaulted,omitempty"` + F_Uint64Defaulted *uint64 `protobuf:"varint,46,opt,name=F_Uint64_defaulted,def=6400" json:"F_Uint64_defaulted,omitempty"` + F_FloatDefaulted *float32 
`protobuf:"fixed32,47,opt,name=F_Float_defaulted,def=314159" json:"F_Float_defaulted,omitempty"` + F_DoubleDefaulted *float64 `protobuf:"fixed64,48,opt,name=F_Double_defaulted,def=271828" json:"F_Double_defaulted,omitempty"` + F_StringDefaulted *string `protobuf:"bytes,49,opt,name=F_String_defaulted,def=hello, \"world!\"\n" json:"F_String_defaulted,omitempty"` + F_BytesDefaulted []byte `protobuf:"bytes,401,opt,name=F_Bytes_defaulted,def=Bignose" json:"F_Bytes_defaulted,omitempty"` + F_Sint32Defaulted *int32 `protobuf:"zigzag32,402,opt,name=F_Sint32_defaulted,def=-32" json:"F_Sint32_defaulted,omitempty"` + F_Sint64Defaulted *int64 `protobuf:"zigzag64,403,opt,name=F_Sint64_defaulted,def=-64" json:"F_Sint64_defaulted,omitempty"` + // Packed repeated fields (no string or bytes). + F_BoolRepeatedPacked []bool `protobuf:"varint,50,rep,packed,name=F_Bool_repeated_packed" json:"F_Bool_repeated_packed,omitempty"` + F_Int32RepeatedPacked []int32 `protobuf:"varint,51,rep,packed,name=F_Int32_repeated_packed" json:"F_Int32_repeated_packed,omitempty"` + F_Int64RepeatedPacked []int64 `protobuf:"varint,52,rep,packed,name=F_Int64_repeated_packed" json:"F_Int64_repeated_packed,omitempty"` + F_Fixed32RepeatedPacked []uint32 `protobuf:"fixed32,53,rep,packed,name=F_Fixed32_repeated_packed" json:"F_Fixed32_repeated_packed,omitempty"` + F_Fixed64RepeatedPacked []uint64 `protobuf:"fixed64,54,rep,packed,name=F_Fixed64_repeated_packed" json:"F_Fixed64_repeated_packed,omitempty"` + F_Uint32RepeatedPacked []uint32 `protobuf:"varint,55,rep,packed,name=F_Uint32_repeated_packed" json:"F_Uint32_repeated_packed,omitempty"` + F_Uint64RepeatedPacked []uint64 `protobuf:"varint,56,rep,packed,name=F_Uint64_repeated_packed" json:"F_Uint64_repeated_packed,omitempty"` + F_FloatRepeatedPacked []float32 `protobuf:"fixed32,57,rep,packed,name=F_Float_repeated_packed" json:"F_Float_repeated_packed,omitempty"` + F_DoubleRepeatedPacked []float64 `protobuf:"fixed64,58,rep,packed,name=F_Double_repeated_packed" json:"F_Double_repeated_packed,omitempty"` + F_Sint32RepeatedPacked []int32 `protobuf:"zigzag32,502,rep,packed,name=F_Sint32_repeated_packed" json:"F_Sint32_repeated_packed,omitempty"` + F_Sint64RepeatedPacked []int64 `protobuf:"zigzag64,503,rep,packed,name=F_Sint64_repeated_packed" json:"F_Sint64_repeated_packed,omitempty"` + Requiredgroup *GoTest_RequiredGroup `protobuf:"group,70,req,name=RequiredGroup" json:"requiredgroup,omitempty"` + Repeatedgroup []*GoTest_RepeatedGroup `protobuf:"group,80,rep,name=RepeatedGroup" json:"repeatedgroup,omitempty"` + Optionalgroup *GoTest_OptionalGroup `protobuf:"group,90,opt,name=OptionalGroup" json:"optionalgroup,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *GoTest) Reset() { *m = GoTest{} } +func (m *GoTest) String() string { return proto.CompactTextString(m) } +func (*GoTest) ProtoMessage() {} + +const Default_GoTest_F_BoolDefaulted bool = true +const Default_GoTest_F_Int32Defaulted int32 = 32 +const Default_GoTest_F_Int64Defaulted int64 = 64 +const Default_GoTest_F_Fixed32Defaulted uint32 = 320 +const Default_GoTest_F_Fixed64Defaulted uint64 = 640 +const Default_GoTest_F_Uint32Defaulted uint32 = 3200 +const Default_GoTest_F_Uint64Defaulted uint64 = 6400 +const Default_GoTest_F_FloatDefaulted float32 = 314159 +const Default_GoTest_F_DoubleDefaulted float64 = 271828 +const Default_GoTest_F_StringDefaulted string = "hello, \"world!\"\n" + +var Default_GoTest_F_BytesDefaulted []byte = []byte("Bignose") + +const Default_GoTest_F_Sint32Defaulted int32 = -32 +const 
Default_GoTest_F_Sint64Defaulted int64 = -64 + +func (m *GoTest) GetKind() GoTest_KIND { + if m != nil && m.Kind != nil { + return *m.Kind + } + return GoTest_VOID +} + +func (m *GoTest) GetTable() string { + if m != nil && m.Table != nil { + return *m.Table + } + return "" +} + +func (m *GoTest) GetParam() int32 { + if m != nil && m.Param != nil { + return *m.Param + } + return 0 +} + +func (m *GoTest) GetRequiredField() *GoTestField { + if m != nil { + return m.RequiredField + } + return nil +} + +func (m *GoTest) GetRepeatedField() []*GoTestField { + if m != nil { + return m.RepeatedField + } + return nil +} + +func (m *GoTest) GetOptionalField() *GoTestField { + if m != nil { + return m.OptionalField + } + return nil +} + +func (m *GoTest) GetF_BoolRequired() bool { + if m != nil && m.F_BoolRequired != nil { + return *m.F_BoolRequired + } + return false +} + +func (m *GoTest) GetF_Int32Required() int32 { + if m != nil && m.F_Int32Required != nil { + return *m.F_Int32Required + } + return 0 +} + +func (m *GoTest) GetF_Int64Required() int64 { + if m != nil && m.F_Int64Required != nil { + return *m.F_Int64Required + } + return 0 +} + +func (m *GoTest) GetF_Fixed32Required() uint32 { + if m != nil && m.F_Fixed32Required != nil { + return *m.F_Fixed32Required + } + return 0 +} + +func (m *GoTest) GetF_Fixed64Required() uint64 { + if m != nil && m.F_Fixed64Required != nil { + return *m.F_Fixed64Required + } + return 0 +} + +func (m *GoTest) GetF_Uint32Required() uint32 { + if m != nil && m.F_Uint32Required != nil { + return *m.F_Uint32Required + } + return 0 +} + +func (m *GoTest) GetF_Uint64Required() uint64 { + if m != nil && m.F_Uint64Required != nil { + return *m.F_Uint64Required + } + return 0 +} + +func (m *GoTest) GetF_FloatRequired() float32 { + if m != nil && m.F_FloatRequired != nil { + return *m.F_FloatRequired + } + return 0 +} + +func (m *GoTest) GetF_DoubleRequired() float64 { + if m != nil && m.F_DoubleRequired != nil { + return *m.F_DoubleRequired + } + return 0 +} + +func (m *GoTest) GetF_StringRequired() string { + if m != nil && m.F_StringRequired != nil { + return *m.F_StringRequired + } + return "" +} + +func (m *GoTest) GetF_BytesRequired() []byte { + if m != nil { + return m.F_BytesRequired + } + return nil +} + +func (m *GoTest) GetF_Sint32Required() int32 { + if m != nil && m.F_Sint32Required != nil { + return *m.F_Sint32Required + } + return 0 +} + +func (m *GoTest) GetF_Sint64Required() int64 { + if m != nil && m.F_Sint64Required != nil { + return *m.F_Sint64Required + } + return 0 +} + +func (m *GoTest) GetF_BoolRepeated() []bool { + if m != nil { + return m.F_BoolRepeated + } + return nil +} + +func (m *GoTest) GetF_Int32Repeated() []int32 { + if m != nil { + return m.F_Int32Repeated + } + return nil +} + +func (m *GoTest) GetF_Int64Repeated() []int64 { + if m != nil { + return m.F_Int64Repeated + } + return nil +} + +func (m *GoTest) GetF_Fixed32Repeated() []uint32 { + if m != nil { + return m.F_Fixed32Repeated + } + return nil +} + +func (m *GoTest) GetF_Fixed64Repeated() []uint64 { + if m != nil { + return m.F_Fixed64Repeated + } + return nil +} + +func (m *GoTest) GetF_Uint32Repeated() []uint32 { + if m != nil { + return m.F_Uint32Repeated + } + return nil +} + +func (m *GoTest) GetF_Uint64Repeated() []uint64 { + if m != nil { + return m.F_Uint64Repeated + } + return nil +} + +func (m *GoTest) GetF_FloatRepeated() []float32 { + if m != nil { + return m.F_FloatRepeated + } + return nil +} + +func (m *GoTest) GetF_DoubleRepeated() []float64 { + if m != nil { + 
return m.F_DoubleRepeated + } + return nil +} + +func (m *GoTest) GetF_StringRepeated() []string { + if m != nil { + return m.F_StringRepeated + } + return nil +} + +func (m *GoTest) GetF_BytesRepeated() [][]byte { + if m != nil { + return m.F_BytesRepeated + } + return nil +} + +func (m *GoTest) GetF_Sint32Repeated() []int32 { + if m != nil { + return m.F_Sint32Repeated + } + return nil +} + +func (m *GoTest) GetF_Sint64Repeated() []int64 { + if m != nil { + return m.F_Sint64Repeated + } + return nil +} + +func (m *GoTest) GetF_BoolOptional() bool { + if m != nil && m.F_BoolOptional != nil { + return *m.F_BoolOptional + } + return false +} + +func (m *GoTest) GetF_Int32Optional() int32 { + if m != nil && m.F_Int32Optional != nil { + return *m.F_Int32Optional + } + return 0 +} + +func (m *GoTest) GetF_Int64Optional() int64 { + if m != nil && m.F_Int64Optional != nil { + return *m.F_Int64Optional + } + return 0 +} + +func (m *GoTest) GetF_Fixed32Optional() uint32 { + if m != nil && m.F_Fixed32Optional != nil { + return *m.F_Fixed32Optional + } + return 0 +} + +func (m *GoTest) GetF_Fixed64Optional() uint64 { + if m != nil && m.F_Fixed64Optional != nil { + return *m.F_Fixed64Optional + } + return 0 +} + +func (m *GoTest) GetF_Uint32Optional() uint32 { + if m != nil && m.F_Uint32Optional != nil { + return *m.F_Uint32Optional + } + return 0 +} + +func (m *GoTest) GetF_Uint64Optional() uint64 { + if m != nil && m.F_Uint64Optional != nil { + return *m.F_Uint64Optional + } + return 0 +} + +func (m *GoTest) GetF_FloatOptional() float32 { + if m != nil && m.F_FloatOptional != nil { + return *m.F_FloatOptional + } + return 0 +} + +func (m *GoTest) GetF_DoubleOptional() float64 { + if m != nil && m.F_DoubleOptional != nil { + return *m.F_DoubleOptional + } + return 0 +} + +func (m *GoTest) GetF_StringOptional() string { + if m != nil && m.F_StringOptional != nil { + return *m.F_StringOptional + } + return "" +} + +func (m *GoTest) GetF_BytesOptional() []byte { + if m != nil { + return m.F_BytesOptional + } + return nil +} + +func (m *GoTest) GetF_Sint32Optional() int32 { + if m != nil && m.F_Sint32Optional != nil { + return *m.F_Sint32Optional + } + return 0 +} + +func (m *GoTest) GetF_Sint64Optional() int64 { + if m != nil && m.F_Sint64Optional != nil { + return *m.F_Sint64Optional + } + return 0 +} + +func (m *GoTest) GetF_BoolDefaulted() bool { + if m != nil && m.F_BoolDefaulted != nil { + return *m.F_BoolDefaulted + } + return Default_GoTest_F_BoolDefaulted +} + +func (m *GoTest) GetF_Int32Defaulted() int32 { + if m != nil && m.F_Int32Defaulted != nil { + return *m.F_Int32Defaulted + } + return Default_GoTest_F_Int32Defaulted +} + +func (m *GoTest) GetF_Int64Defaulted() int64 { + if m != nil && m.F_Int64Defaulted != nil { + return *m.F_Int64Defaulted + } + return Default_GoTest_F_Int64Defaulted +} + +func (m *GoTest) GetF_Fixed32Defaulted() uint32 { + if m != nil && m.F_Fixed32Defaulted != nil { + return *m.F_Fixed32Defaulted + } + return Default_GoTest_F_Fixed32Defaulted +} + +func (m *GoTest) GetF_Fixed64Defaulted() uint64 { + if m != nil && m.F_Fixed64Defaulted != nil { + return *m.F_Fixed64Defaulted + } + return Default_GoTest_F_Fixed64Defaulted +} + +func (m *GoTest) GetF_Uint32Defaulted() uint32 { + if m != nil && m.F_Uint32Defaulted != nil { + return *m.F_Uint32Defaulted + } + return Default_GoTest_F_Uint32Defaulted +} + +func (m *GoTest) GetF_Uint64Defaulted() uint64 { + if m != nil && m.F_Uint64Defaulted != nil { + return *m.F_Uint64Defaulted + } + return Default_GoTest_F_Uint64Defaulted 
+} + +func (m *GoTest) GetF_FloatDefaulted() float32 { + if m != nil && m.F_FloatDefaulted != nil { + return *m.F_FloatDefaulted + } + return Default_GoTest_F_FloatDefaulted +} + +func (m *GoTest) GetF_DoubleDefaulted() float64 { + if m != nil && m.F_DoubleDefaulted != nil { + return *m.F_DoubleDefaulted + } + return Default_GoTest_F_DoubleDefaulted +} + +func (m *GoTest) GetF_StringDefaulted() string { + if m != nil && m.F_StringDefaulted != nil { + return *m.F_StringDefaulted + } + return Default_GoTest_F_StringDefaulted +} + +func (m *GoTest) GetF_BytesDefaulted() []byte { + if m != nil && m.F_BytesDefaulted != nil { + return m.F_BytesDefaulted + } + return append([]byte(nil), Default_GoTest_F_BytesDefaulted...) +} + +func (m *GoTest) GetF_Sint32Defaulted() int32 { + if m != nil && m.F_Sint32Defaulted != nil { + return *m.F_Sint32Defaulted + } + return Default_GoTest_F_Sint32Defaulted +} + +func (m *GoTest) GetF_Sint64Defaulted() int64 { + if m != nil && m.F_Sint64Defaulted != nil { + return *m.F_Sint64Defaulted + } + return Default_GoTest_F_Sint64Defaulted +} + +func (m *GoTest) GetF_BoolRepeatedPacked() []bool { + if m != nil { + return m.F_BoolRepeatedPacked + } + return nil +} + +func (m *GoTest) GetF_Int32RepeatedPacked() []int32 { + if m != nil { + return m.F_Int32RepeatedPacked + } + return nil +} + +func (m *GoTest) GetF_Int64RepeatedPacked() []int64 { + if m != nil { + return m.F_Int64RepeatedPacked + } + return nil +} + +func (m *GoTest) GetF_Fixed32RepeatedPacked() []uint32 { + if m != nil { + return m.F_Fixed32RepeatedPacked + } + return nil +} + +func (m *GoTest) GetF_Fixed64RepeatedPacked() []uint64 { + if m != nil { + return m.F_Fixed64RepeatedPacked + } + return nil +} + +func (m *GoTest) GetF_Uint32RepeatedPacked() []uint32 { + if m != nil { + return m.F_Uint32RepeatedPacked + } + return nil +} + +func (m *GoTest) GetF_Uint64RepeatedPacked() []uint64 { + if m != nil { + return m.F_Uint64RepeatedPacked + } + return nil +} + +func (m *GoTest) GetF_FloatRepeatedPacked() []float32 { + if m != nil { + return m.F_FloatRepeatedPacked + } + return nil +} + +func (m *GoTest) GetF_DoubleRepeatedPacked() []float64 { + if m != nil { + return m.F_DoubleRepeatedPacked + } + return nil +} + +func (m *GoTest) GetF_Sint32RepeatedPacked() []int32 { + if m != nil { + return m.F_Sint32RepeatedPacked + } + return nil +} + +func (m *GoTest) GetF_Sint64RepeatedPacked() []int64 { + if m != nil { + return m.F_Sint64RepeatedPacked + } + return nil +} + +func (m *GoTest) GetRequiredgroup() *GoTest_RequiredGroup { + if m != nil { + return m.Requiredgroup + } + return nil +} + +func (m *GoTest) GetRepeatedgroup() []*GoTest_RepeatedGroup { + if m != nil { + return m.Repeatedgroup + } + return nil +} + +func (m *GoTest) GetOptionalgroup() *GoTest_OptionalGroup { + if m != nil { + return m.Optionalgroup + } + return nil +} + +// Required, repeated, and optional groups. 
+type GoTest_RequiredGroup struct { + RequiredField *string `protobuf:"bytes,71,req" json:"RequiredField,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *GoTest_RequiredGroup) Reset() { *m = GoTest_RequiredGroup{} } +func (m *GoTest_RequiredGroup) String() string { return proto.CompactTextString(m) } +func (*GoTest_RequiredGroup) ProtoMessage() {} + +func (m *GoTest_RequiredGroup) GetRequiredField() string { + if m != nil && m.RequiredField != nil { + return *m.RequiredField + } + return "" +} + +type GoTest_RepeatedGroup struct { + RequiredField *string `protobuf:"bytes,81,req" json:"RequiredField,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *GoTest_RepeatedGroup) Reset() { *m = GoTest_RepeatedGroup{} } +func (m *GoTest_RepeatedGroup) String() string { return proto.CompactTextString(m) } +func (*GoTest_RepeatedGroup) ProtoMessage() {} + +func (m *GoTest_RepeatedGroup) GetRequiredField() string { + if m != nil && m.RequiredField != nil { + return *m.RequiredField + } + return "" +} + +type GoTest_OptionalGroup struct { + RequiredField *string `protobuf:"bytes,91,req" json:"RequiredField,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *GoTest_OptionalGroup) Reset() { *m = GoTest_OptionalGroup{} } +func (m *GoTest_OptionalGroup) String() string { return proto.CompactTextString(m) } +func (*GoTest_OptionalGroup) ProtoMessage() {} + +func (m *GoTest_OptionalGroup) GetRequiredField() string { + if m != nil && m.RequiredField != nil { + return *m.RequiredField + } + return "" +} + +// For testing skipping of unrecognized fields. +// Numbers are all big, larger than tag numbers in GoTestField, +// the message used in the corresponding test. +type GoSkipTest struct { + SkipInt32 *int32 `protobuf:"varint,11,req,name=skip_int32" json:"skip_int32,omitempty"` + SkipFixed32 *uint32 `protobuf:"fixed32,12,req,name=skip_fixed32" json:"skip_fixed32,omitempty"` + SkipFixed64 *uint64 `protobuf:"fixed64,13,req,name=skip_fixed64" json:"skip_fixed64,omitempty"` + SkipString *string `protobuf:"bytes,14,req,name=skip_string" json:"skip_string,omitempty"` + Skipgroup *GoSkipTest_SkipGroup `protobuf:"group,15,req,name=SkipGroup" json:"skipgroup,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *GoSkipTest) Reset() { *m = GoSkipTest{} } +func (m *GoSkipTest) String() string { return proto.CompactTextString(m) } +func (*GoSkipTest) ProtoMessage() {} + +func (m *GoSkipTest) GetSkipInt32() int32 { + if m != nil && m.SkipInt32 != nil { + return *m.SkipInt32 + } + return 0 +} + +func (m *GoSkipTest) GetSkipFixed32() uint32 { + if m != nil && m.SkipFixed32 != nil { + return *m.SkipFixed32 + } + return 0 +} + +func (m *GoSkipTest) GetSkipFixed64() uint64 { + if m != nil && m.SkipFixed64 != nil { + return *m.SkipFixed64 + } + return 0 +} + +func (m *GoSkipTest) GetSkipString() string { + if m != nil && m.SkipString != nil { + return *m.SkipString + } + return "" +} + +func (m *GoSkipTest) GetSkipgroup() *GoSkipTest_SkipGroup { + if m != nil { + return m.Skipgroup + } + return nil +} + +type GoSkipTest_SkipGroup struct { + GroupInt32 *int32 `protobuf:"varint,16,req,name=group_int32" json:"group_int32,omitempty"` + GroupString *string `protobuf:"bytes,17,req,name=group_string" json:"group_string,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *GoSkipTest_SkipGroup) Reset() { *m = GoSkipTest_SkipGroup{} } +func (m *GoSkipTest_SkipGroup) String() string { return proto.CompactTextString(m) } +func (*GoSkipTest_SkipGroup) ProtoMessage() {} + +func (m 
*GoSkipTest_SkipGroup) GetGroupInt32() int32 { + if m != nil && m.GroupInt32 != nil { + return *m.GroupInt32 + } + return 0 +} + +func (m *GoSkipTest_SkipGroup) GetGroupString() string { + if m != nil && m.GroupString != nil { + return *m.GroupString + } + return "" +} + +// For testing packed/non-packed decoder switching. +// A serialized instance of one should be deserializable as the other. +type NonPackedTest struct { + A []int32 `protobuf:"varint,1,rep,name=a" json:"a,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *NonPackedTest) Reset() { *m = NonPackedTest{} } +func (m *NonPackedTest) String() string { return proto.CompactTextString(m) } +func (*NonPackedTest) ProtoMessage() {} + +func (m *NonPackedTest) GetA() []int32 { + if m != nil { + return m.A + } + return nil +} + +type PackedTest struct { + B []int32 `protobuf:"varint,1,rep,packed,name=b" json:"b,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *PackedTest) Reset() { *m = PackedTest{} } +func (m *PackedTest) String() string { return proto.CompactTextString(m) } +func (*PackedTest) ProtoMessage() {} + +func (m *PackedTest) GetB() []int32 { + if m != nil { + return m.B + } + return nil +} + +type MaxTag struct { + // Maximum possible tag number. + LastField *string `protobuf:"bytes,536870911,opt,name=last_field" json:"last_field,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *MaxTag) Reset() { *m = MaxTag{} } +func (m *MaxTag) String() string { return proto.CompactTextString(m) } +func (*MaxTag) ProtoMessage() {} + +func (m *MaxTag) GetLastField() string { + if m != nil && m.LastField != nil { + return *m.LastField + } + return "" +} + +type OldMessage struct { + Nested *OldMessage_Nested `protobuf:"bytes,1,opt,name=nested" json:"nested,omitempty"` + Num *int32 `protobuf:"varint,2,opt,name=num" json:"num,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *OldMessage) Reset() { *m = OldMessage{} } +func (m *OldMessage) String() string { return proto.CompactTextString(m) } +func (*OldMessage) ProtoMessage() {} + +func (m *OldMessage) GetNested() *OldMessage_Nested { + if m != nil { + return m.Nested + } + return nil +} + +func (m *OldMessage) GetNum() int32 { + if m != nil && m.Num != nil { + return *m.Num + } + return 0 +} + +type OldMessage_Nested struct { + Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *OldMessage_Nested) Reset() { *m = OldMessage_Nested{} } +func (m *OldMessage_Nested) String() string { return proto.CompactTextString(m) } +func (*OldMessage_Nested) ProtoMessage() {} + +func (m *OldMessage_Nested) GetName() string { + if m != nil && m.Name != nil { + return *m.Name + } + return "" +} + +// NewMessage is wire compatible with OldMessage; +// imagine it as a future version. +type NewMessage struct { + Nested *NewMessage_Nested `protobuf:"bytes,1,opt,name=nested" json:"nested,omitempty"` + // This is an int32 in OldMessage. 
+ Num *int64 `protobuf:"varint,2,opt,name=num" json:"num,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *NewMessage) Reset() { *m = NewMessage{} } +func (m *NewMessage) String() string { return proto.CompactTextString(m) } +func (*NewMessage) ProtoMessage() {} + +func (m *NewMessage) GetNested() *NewMessage_Nested { + if m != nil { + return m.Nested + } + return nil +} + +func (m *NewMessage) GetNum() int64 { + if m != nil && m.Num != nil { + return *m.Num + } + return 0 +} + +type NewMessage_Nested struct { + Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` + FoodGroup *string `protobuf:"bytes,2,opt,name=food_group" json:"food_group,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *NewMessage_Nested) Reset() { *m = NewMessage_Nested{} } +func (m *NewMessage_Nested) String() string { return proto.CompactTextString(m) } +func (*NewMessage_Nested) ProtoMessage() {} + +func (m *NewMessage_Nested) GetName() string { + if m != nil && m.Name != nil { + return *m.Name + } + return "" +} + +func (m *NewMessage_Nested) GetFoodGroup() string { + if m != nil && m.FoodGroup != nil { + return *m.FoodGroup + } + return "" +} + +type InnerMessage struct { + Host *string `protobuf:"bytes,1,req,name=host" json:"host,omitempty"` + Port *int32 `protobuf:"varint,2,opt,name=port,def=4000" json:"port,omitempty"` + Connected *bool `protobuf:"varint,3,opt,name=connected" json:"connected,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *InnerMessage) Reset() { *m = InnerMessage{} } +func (m *InnerMessage) String() string { return proto.CompactTextString(m) } +func (*InnerMessage) ProtoMessage() {} + +const Default_InnerMessage_Port int32 = 4000 + +func (m *InnerMessage) GetHost() string { + if m != nil && m.Host != nil { + return *m.Host + } + return "" +} + +func (m *InnerMessage) GetPort() int32 { + if m != nil && m.Port != nil { + return *m.Port + } + return Default_InnerMessage_Port +} + +func (m *InnerMessage) GetConnected() bool { + if m != nil && m.Connected != nil { + return *m.Connected + } + return false +} + +type OtherMessage struct { + Key *int64 `protobuf:"varint,1,opt,name=key" json:"key,omitempty"` + Value []byte `protobuf:"bytes,2,opt,name=value" json:"value,omitempty"` + Weight *float32 `protobuf:"fixed32,3,opt,name=weight" json:"weight,omitempty"` + Inner *InnerMessage `protobuf:"bytes,4,opt,name=inner" json:"inner,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *OtherMessage) Reset() { *m = OtherMessage{} } +func (m *OtherMessage) String() string { return proto.CompactTextString(m) } +func (*OtherMessage) ProtoMessage() {} + +func (m *OtherMessage) GetKey() int64 { + if m != nil && m.Key != nil { + return *m.Key + } + return 0 +} + +func (m *OtherMessage) GetValue() []byte { + if m != nil { + return m.Value + } + return nil +} + +func (m *OtherMessage) GetWeight() float32 { + if m != nil && m.Weight != nil { + return *m.Weight + } + return 0 +} + +func (m *OtherMessage) GetInner() *InnerMessage { + if m != nil { + return m.Inner + } + return nil +} + +type MyMessage struct { + Count *int32 `protobuf:"varint,1,req,name=count" json:"count,omitempty"` + Name *string `protobuf:"bytes,2,opt,name=name" json:"name,omitempty"` + Quote *string `protobuf:"bytes,3,opt,name=quote" json:"quote,omitempty"` + Pet []string `protobuf:"bytes,4,rep,name=pet" json:"pet,omitempty"` + Inner *InnerMessage `protobuf:"bytes,5,opt,name=inner" json:"inner,omitempty"` + Others []*OtherMessage `protobuf:"bytes,6,rep,name=others" 
json:"others,omitempty"` + RepInner []*InnerMessage `protobuf:"bytes,12,rep,name=rep_inner" json:"rep_inner,omitempty"` + Bikeshed *MyMessage_Color `protobuf:"varint,7,opt,name=bikeshed,enum=testdata.MyMessage_Color" json:"bikeshed,omitempty"` + Somegroup *MyMessage_SomeGroup `protobuf:"group,8,opt,name=SomeGroup" json:"somegroup,omitempty"` + // This field becomes [][]byte in the generated code. + RepBytes [][]byte `protobuf:"bytes,10,rep,name=rep_bytes" json:"rep_bytes,omitempty"` + Bigfloat *float64 `protobuf:"fixed64,11,opt,name=bigfloat" json:"bigfloat,omitempty"` + XXX_extensions map[int32]proto.Extension `json:"-"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *MyMessage) Reset() { *m = MyMessage{} } +func (m *MyMessage) String() string { return proto.CompactTextString(m) } +func (*MyMessage) ProtoMessage() {} + +var extRange_MyMessage = []proto.ExtensionRange{ + {100, 536870911}, +} + +func (*MyMessage) ExtensionRangeArray() []proto.ExtensionRange { + return extRange_MyMessage +} +func (m *MyMessage) ExtensionMap() map[int32]proto.Extension { + if m.XXX_extensions == nil { + m.XXX_extensions = make(map[int32]proto.Extension) + } + return m.XXX_extensions +} + +func (m *MyMessage) GetCount() int32 { + if m != nil && m.Count != nil { + return *m.Count + } + return 0 +} + +func (m *MyMessage) GetName() string { + if m != nil && m.Name != nil { + return *m.Name + } + return "" +} + +func (m *MyMessage) GetQuote() string { + if m != nil && m.Quote != nil { + return *m.Quote + } + return "" +} + +func (m *MyMessage) GetPet() []string { + if m != nil { + return m.Pet + } + return nil +} + +func (m *MyMessage) GetInner() *InnerMessage { + if m != nil { + return m.Inner + } + return nil +} + +func (m *MyMessage) GetOthers() []*OtherMessage { + if m != nil { + return m.Others + } + return nil +} + +func (m *MyMessage) GetRepInner() []*InnerMessage { + if m != nil { + return m.RepInner + } + return nil +} + +func (m *MyMessage) GetBikeshed() MyMessage_Color { + if m != nil && m.Bikeshed != nil { + return *m.Bikeshed + } + return MyMessage_RED +} + +func (m *MyMessage) GetSomegroup() *MyMessage_SomeGroup { + if m != nil { + return m.Somegroup + } + return nil +} + +func (m *MyMessage) GetRepBytes() [][]byte { + if m != nil { + return m.RepBytes + } + return nil +} + +func (m *MyMessage) GetBigfloat() float64 { + if m != nil && m.Bigfloat != nil { + return *m.Bigfloat + } + return 0 +} + +type MyMessage_SomeGroup struct { + GroupField *int32 `protobuf:"varint,9,opt,name=group_field" json:"group_field,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *MyMessage_SomeGroup) Reset() { *m = MyMessage_SomeGroup{} } +func (m *MyMessage_SomeGroup) String() string { return proto.CompactTextString(m) } +func (*MyMessage_SomeGroup) ProtoMessage() {} + +func (m *MyMessage_SomeGroup) GetGroupField() int32 { + if m != nil && m.GroupField != nil { + return *m.GroupField + } + return 0 +} + +type Ext struct { + Data *string `protobuf:"bytes,1,opt,name=data" json:"data,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *Ext) Reset() { *m = Ext{} } +func (m *Ext) String() string { return proto.CompactTextString(m) } +func (*Ext) ProtoMessage() {} + +func (m *Ext) GetData() string { + if m != nil && m.Data != nil { + return *m.Data + } + return "" +} + +var E_Ext_More = &proto.ExtensionDesc{ + ExtendedType: (*MyMessage)(nil), + ExtensionType: (*Ext)(nil), + Field: 103, + Name: "testdata.Ext.more", + Tag: "bytes,103,opt,name=more", +} + +var E_Ext_Text = &proto.ExtensionDesc{ + 
ExtendedType: (*MyMessage)(nil), + ExtensionType: (*string)(nil), + Field: 104, + Name: "testdata.Ext.text", + Tag: "bytes,104,opt,name=text", +} + +var E_Ext_Number = &proto.ExtensionDesc{ + ExtendedType: (*MyMessage)(nil), + ExtensionType: (*int32)(nil), + Field: 105, + Name: "testdata.Ext.number", + Tag: "varint,105,opt,name=number", +} + +type MyMessageSet struct { + XXX_extensions map[int32]proto.Extension `json:"-"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *MyMessageSet) Reset() { *m = MyMessageSet{} } +func (m *MyMessageSet) String() string { return proto.CompactTextString(m) } +func (*MyMessageSet) ProtoMessage() {} + +func (m *MyMessageSet) Marshal() ([]byte, error) { + return proto.MarshalMessageSet(m.ExtensionMap()) +} +func (m *MyMessageSet) Unmarshal(buf []byte) error { + return proto.UnmarshalMessageSet(buf, m.ExtensionMap()) +} + +// ensure MyMessageSet satisfies proto.Marshaler and proto.Unmarshaler +var _ proto.Marshaler = (*MyMessageSet)(nil) +var _ proto.Unmarshaler = (*MyMessageSet)(nil) + +var extRange_MyMessageSet = []proto.ExtensionRange{ + {100, 2147483646}, +} + +func (*MyMessageSet) ExtensionRangeArray() []proto.ExtensionRange { + return extRange_MyMessageSet +} +func (m *MyMessageSet) ExtensionMap() map[int32]proto.Extension { + if m.XXX_extensions == nil { + m.XXX_extensions = make(map[int32]proto.Extension) + } + return m.XXX_extensions +} + +type Empty struct { + XXX_unrecognized []byte `json:"-"` +} + +func (m *Empty) Reset() { *m = Empty{} } +func (m *Empty) String() string { return proto.CompactTextString(m) } +func (*Empty) ProtoMessage() {} + +type MessageList struct { + Message []*MessageList_Message `protobuf:"group,1,rep" json:"message,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *MessageList) Reset() { *m = MessageList{} } +func (m *MessageList) String() string { return proto.CompactTextString(m) } +func (*MessageList) ProtoMessage() {} + +func (m *MessageList) GetMessage() []*MessageList_Message { + if m != nil { + return m.Message + } + return nil +} + +type MessageList_Message struct { + Name *string `protobuf:"bytes,2,req,name=name" json:"name,omitempty"` + Count *int32 `protobuf:"varint,3,req,name=count" json:"count,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *MessageList_Message) Reset() { *m = MessageList_Message{} } +func (m *MessageList_Message) String() string { return proto.CompactTextString(m) } +func (*MessageList_Message) ProtoMessage() {} + +func (m *MessageList_Message) GetName() string { + if m != nil && m.Name != nil { + return *m.Name + } + return "" +} + +func (m *MessageList_Message) GetCount() int32 { + if m != nil && m.Count != nil { + return *m.Count + } + return 0 +} + +type Strings struct { + StringField *string `protobuf:"bytes,1,opt,name=string_field" json:"string_field,omitempty"` + BytesField []byte `protobuf:"bytes,2,opt,name=bytes_field" json:"bytes_field,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *Strings) Reset() { *m = Strings{} } +func (m *Strings) String() string { return proto.CompactTextString(m) } +func (*Strings) ProtoMessage() {} + +func (m *Strings) GetStringField() string { + if m != nil && m.StringField != nil { + return *m.StringField + } + return "" +} + +func (m *Strings) GetBytesField() []byte { + if m != nil { + return m.BytesField + } + return nil +} + +type Defaults struct { + // Default-valued fields of all basic types. + // Same as GoTest, but copied here to make testing easier. 
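+	// When a field pointer is nil, the generated getter below falls back to
+	// the corresponding Default_Defaults_* constant; a small sketch:
+	//
+	//	d := new(Defaults)
+	//	d.GetF_Int32()             // 32, the declared default
+	//	d.F_Int32 = proto.Int32(7)
+	//	d.GetF_Int32()             // 7, the explicitly set value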
+ F_Bool *bool `protobuf:"varint,1,opt,def=1" json:"F_Bool,omitempty"` + F_Int32 *int32 `protobuf:"varint,2,opt,def=32" json:"F_Int32,omitempty"` + F_Int64 *int64 `protobuf:"varint,3,opt,def=64" json:"F_Int64,omitempty"` + F_Fixed32 *uint32 `protobuf:"fixed32,4,opt,def=320" json:"F_Fixed32,omitempty"` + F_Fixed64 *uint64 `protobuf:"fixed64,5,opt,def=640" json:"F_Fixed64,omitempty"` + F_Uint32 *uint32 `protobuf:"varint,6,opt,def=3200" json:"F_Uint32,omitempty"` + F_Uint64 *uint64 `protobuf:"varint,7,opt,def=6400" json:"F_Uint64,omitempty"` + F_Float *float32 `protobuf:"fixed32,8,opt,def=314159" json:"F_Float,omitempty"` + F_Double *float64 `protobuf:"fixed64,9,opt,def=271828" json:"F_Double,omitempty"` + F_String *string `protobuf:"bytes,10,opt,def=hello, \"world!\"\n" json:"F_String,omitempty"` + F_Bytes []byte `protobuf:"bytes,11,opt,def=Bignose" json:"F_Bytes,omitempty"` + F_Sint32 *int32 `protobuf:"zigzag32,12,opt,def=-32" json:"F_Sint32,omitempty"` + F_Sint64 *int64 `protobuf:"zigzag64,13,opt,def=-64" json:"F_Sint64,omitempty"` + F_Enum *Defaults_Color `protobuf:"varint,14,opt,enum=testdata.Defaults_Color,def=1" json:"F_Enum,omitempty"` + // More fields with crazy defaults. + F_Pinf *float32 `protobuf:"fixed32,15,opt,def=inf" json:"F_Pinf,omitempty"` + F_Ninf *float32 `protobuf:"fixed32,16,opt,def=-inf" json:"F_Ninf,omitempty"` + F_Nan *float32 `protobuf:"fixed32,17,opt,def=nan" json:"F_Nan,omitempty"` + // Sub-message. + Sub *SubDefaults `protobuf:"bytes,18,opt,name=sub" json:"sub,omitempty"` + // Redundant but explicit defaults. + StrZero *string `protobuf:"bytes,19,opt,name=str_zero,def=" json:"str_zero,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *Defaults) Reset() { *m = Defaults{} } +func (m *Defaults) String() string { return proto.CompactTextString(m) } +func (*Defaults) ProtoMessage() {} + +const Default_Defaults_F_Bool bool = true +const Default_Defaults_F_Int32 int32 = 32 +const Default_Defaults_F_Int64 int64 = 64 +const Default_Defaults_F_Fixed32 uint32 = 320 +const Default_Defaults_F_Fixed64 uint64 = 640 +const Default_Defaults_F_Uint32 uint32 = 3200 +const Default_Defaults_F_Uint64 uint64 = 6400 +const Default_Defaults_F_Float float32 = 314159 +const Default_Defaults_F_Double float64 = 271828 +const Default_Defaults_F_String string = "hello, \"world!\"\n" + +var Default_Defaults_F_Bytes []byte = []byte("Bignose") + +const Default_Defaults_F_Sint32 int32 = -32 +const Default_Defaults_F_Sint64 int64 = -64 +const Default_Defaults_F_Enum Defaults_Color = Defaults_GREEN + +var Default_Defaults_F_Pinf float32 = float32(math.Inf(1)) +var Default_Defaults_F_Ninf float32 = float32(math.Inf(-1)) +var Default_Defaults_F_Nan float32 = float32(math.NaN()) + +func (m *Defaults) GetF_Bool() bool { + if m != nil && m.F_Bool != nil { + return *m.F_Bool + } + return Default_Defaults_F_Bool +} + +func (m *Defaults) GetF_Int32() int32 { + if m != nil && m.F_Int32 != nil { + return *m.F_Int32 + } + return Default_Defaults_F_Int32 +} + +func (m *Defaults) GetF_Int64() int64 { + if m != nil && m.F_Int64 != nil { + return *m.F_Int64 + } + return Default_Defaults_F_Int64 +} + +func (m *Defaults) GetF_Fixed32() uint32 { + if m != nil && m.F_Fixed32 != nil { + return *m.F_Fixed32 + } + return Default_Defaults_F_Fixed32 +} + +func (m *Defaults) GetF_Fixed64() uint64 { + if m != nil && m.F_Fixed64 != nil { + return *m.F_Fixed64 + } + return Default_Defaults_F_Fixed64 +} + +func (m *Defaults) GetF_Uint32() uint32 { + if m != nil && m.F_Uint32 != nil { + return *m.F_Uint32 + } + 
return Default_Defaults_F_Uint32 +} + +func (m *Defaults) GetF_Uint64() uint64 { + if m != nil && m.F_Uint64 != nil { + return *m.F_Uint64 + } + return Default_Defaults_F_Uint64 +} + +func (m *Defaults) GetF_Float() float32 { + if m != nil && m.F_Float != nil { + return *m.F_Float + } + return Default_Defaults_F_Float +} + +func (m *Defaults) GetF_Double() float64 { + if m != nil && m.F_Double != nil { + return *m.F_Double + } + return Default_Defaults_F_Double +} + +func (m *Defaults) GetF_String() string { + if m != nil && m.F_String != nil { + return *m.F_String + } + return Default_Defaults_F_String +} + +func (m *Defaults) GetF_Bytes() []byte { + if m != nil && m.F_Bytes != nil { + return m.F_Bytes + } + return append([]byte(nil), Default_Defaults_F_Bytes...) +} + +func (m *Defaults) GetF_Sint32() int32 { + if m != nil && m.F_Sint32 != nil { + return *m.F_Sint32 + } + return Default_Defaults_F_Sint32 +} + +func (m *Defaults) GetF_Sint64() int64 { + if m != nil && m.F_Sint64 != nil { + return *m.F_Sint64 + } + return Default_Defaults_F_Sint64 +} + +func (m *Defaults) GetF_Enum() Defaults_Color { + if m != nil && m.F_Enum != nil { + return *m.F_Enum + } + return Default_Defaults_F_Enum +} + +func (m *Defaults) GetF_Pinf() float32 { + if m != nil && m.F_Pinf != nil { + return *m.F_Pinf + } + return Default_Defaults_F_Pinf +} + +func (m *Defaults) GetF_Ninf() float32 { + if m != nil && m.F_Ninf != nil { + return *m.F_Ninf + } + return Default_Defaults_F_Ninf +} + +func (m *Defaults) GetF_Nan() float32 { + if m != nil && m.F_Nan != nil { + return *m.F_Nan + } + return Default_Defaults_F_Nan +} + +func (m *Defaults) GetSub() *SubDefaults { + if m != nil { + return m.Sub + } + return nil +} + +func (m *Defaults) GetStrZero() string { + if m != nil && m.StrZero != nil { + return *m.StrZero + } + return "" +} + +type SubDefaults struct { + N *int64 `protobuf:"varint,1,opt,name=n,def=7" json:"n,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *SubDefaults) Reset() { *m = SubDefaults{} } +func (m *SubDefaults) String() string { return proto.CompactTextString(m) } +func (*SubDefaults) ProtoMessage() {} + +const Default_SubDefaults_N int64 = 7 + +func (m *SubDefaults) GetN() int64 { + if m != nil && m.N != nil { + return *m.N + } + return Default_SubDefaults_N +} + +type RepeatedEnum struct { + Color []RepeatedEnum_Color `protobuf:"varint,1,rep,name=color,enum=testdata.RepeatedEnum_Color" json:"color,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *RepeatedEnum) Reset() { *m = RepeatedEnum{} } +func (m *RepeatedEnum) String() string { return proto.CompactTextString(m) } +func (*RepeatedEnum) ProtoMessage() {} + +func (m *RepeatedEnum) GetColor() []RepeatedEnum_Color { + if m != nil { + return m.Color + } + return nil +} + +type MoreRepeated struct { + Bools []bool `protobuf:"varint,1,rep,name=bools" json:"bools,omitempty"` + BoolsPacked []bool `protobuf:"varint,2,rep,packed,name=bools_packed" json:"bools_packed,omitempty"` + Ints []int32 `protobuf:"varint,3,rep,name=ints" json:"ints,omitempty"` + IntsPacked []int32 `protobuf:"varint,4,rep,packed,name=ints_packed" json:"ints_packed,omitempty"` + Int64SPacked []int64 `protobuf:"varint,7,rep,packed,name=int64s_packed" json:"int64s_packed,omitempty"` + Strings []string `protobuf:"bytes,5,rep,name=strings" json:"strings,omitempty"` + Fixeds []uint32 `protobuf:"fixed32,6,rep,name=fixeds" json:"fixeds,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *MoreRepeated) Reset() { *m = MoreRepeated{} } +func (m 
*MoreRepeated) String() string { return proto.CompactTextString(m) } +func (*MoreRepeated) ProtoMessage() {} + +func (m *MoreRepeated) GetBools() []bool { + if m != nil { + return m.Bools + } + return nil +} + +func (m *MoreRepeated) GetBoolsPacked() []bool { + if m != nil { + return m.BoolsPacked + } + return nil +} + +func (m *MoreRepeated) GetInts() []int32 { + if m != nil { + return m.Ints + } + return nil +} + +func (m *MoreRepeated) GetIntsPacked() []int32 { + if m != nil { + return m.IntsPacked + } + return nil +} + +func (m *MoreRepeated) GetInt64SPacked() []int64 { + if m != nil { + return m.Int64SPacked + } + return nil +} + +func (m *MoreRepeated) GetStrings() []string { + if m != nil { + return m.Strings + } + return nil +} + +func (m *MoreRepeated) GetFixeds() []uint32 { + if m != nil { + return m.Fixeds + } + return nil +} + +type GroupOld struct { + G *GroupOld_G `protobuf:"group,101,opt" json:"g,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *GroupOld) Reset() { *m = GroupOld{} } +func (m *GroupOld) String() string { return proto.CompactTextString(m) } +func (*GroupOld) ProtoMessage() {} + +func (m *GroupOld) GetG() *GroupOld_G { + if m != nil { + return m.G + } + return nil +} + +type GroupOld_G struct { + X *int32 `protobuf:"varint,2,opt,name=x" json:"x,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *GroupOld_G) Reset() { *m = GroupOld_G{} } +func (m *GroupOld_G) String() string { return proto.CompactTextString(m) } +func (*GroupOld_G) ProtoMessage() {} + +func (m *GroupOld_G) GetX() int32 { + if m != nil && m.X != nil { + return *m.X + } + return 0 +} + +type GroupNew struct { + G *GroupNew_G `protobuf:"group,101,opt" json:"g,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *GroupNew) Reset() { *m = GroupNew{} } +func (m *GroupNew) String() string { return proto.CompactTextString(m) } +func (*GroupNew) ProtoMessage() {} + +func (m *GroupNew) GetG() *GroupNew_G { + if m != nil { + return m.G + } + return nil +} + +type GroupNew_G struct { + X *int32 `protobuf:"varint,2,opt,name=x" json:"x,omitempty"` + Y *int32 `protobuf:"varint,3,opt,name=y" json:"y,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *GroupNew_G) Reset() { *m = GroupNew_G{} } +func (m *GroupNew_G) String() string { return proto.CompactTextString(m) } +func (*GroupNew_G) ProtoMessage() {} + +func (m *GroupNew_G) GetX() int32 { + if m != nil && m.X != nil { + return *m.X + } + return 0 +} + +func (m *GroupNew_G) GetY() int32 { + if m != nil && m.Y != nil { + return *m.Y + } + return 0 +} + +type FloatingPoint struct { + F *float64 `protobuf:"fixed64,1,req,name=f" json:"f,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *FloatingPoint) Reset() { *m = FloatingPoint{} } +func (m *FloatingPoint) String() string { return proto.CompactTextString(m) } +func (*FloatingPoint) ProtoMessage() {} + +func (m *FloatingPoint) GetF() float64 { + if m != nil && m.F != nil { + return *m.F + } + return 0 +} + +var E_Greeting = &proto.ExtensionDesc{ + ExtendedType: (*MyMessage)(nil), + ExtensionType: ([]string)(nil), + Field: 106, + Name: "testdata.greeting", + Tag: "bytes,106,rep,name=greeting", +} + +var E_X201 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 201, + Name: "testdata.x201", + Tag: "bytes,201,opt,name=x201", +} + +var E_X202 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 202, + Name: "testdata.x202", + Tag: 
"bytes,202,opt,name=x202", +} + +var E_X203 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 203, + Name: "testdata.x203", + Tag: "bytes,203,opt,name=x203", +} + +var E_X204 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 204, + Name: "testdata.x204", + Tag: "bytes,204,opt,name=x204", +} + +var E_X205 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 205, + Name: "testdata.x205", + Tag: "bytes,205,opt,name=x205", +} + +var E_X206 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 206, + Name: "testdata.x206", + Tag: "bytes,206,opt,name=x206", +} + +var E_X207 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 207, + Name: "testdata.x207", + Tag: "bytes,207,opt,name=x207", +} + +var E_X208 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 208, + Name: "testdata.x208", + Tag: "bytes,208,opt,name=x208", +} + +var E_X209 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 209, + Name: "testdata.x209", + Tag: "bytes,209,opt,name=x209", +} + +var E_X210 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 210, + Name: "testdata.x210", + Tag: "bytes,210,opt,name=x210", +} + +var E_X211 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 211, + Name: "testdata.x211", + Tag: "bytes,211,opt,name=x211", +} + +var E_X212 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 212, + Name: "testdata.x212", + Tag: "bytes,212,opt,name=x212", +} + +var E_X213 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 213, + Name: "testdata.x213", + Tag: "bytes,213,opt,name=x213", +} + +var E_X214 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 214, + Name: "testdata.x214", + Tag: "bytes,214,opt,name=x214", +} + +var E_X215 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 215, + Name: "testdata.x215", + Tag: "bytes,215,opt,name=x215", +} + +var E_X216 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 216, + Name: "testdata.x216", + Tag: "bytes,216,opt,name=x216", +} + +var E_X217 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 217, + Name: "testdata.x217", + Tag: "bytes,217,opt,name=x217", +} + +var E_X218 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 218, + Name: "testdata.x218", + Tag: "bytes,218,opt,name=x218", +} + +var E_X219 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 219, + Name: "testdata.x219", + Tag: "bytes,219,opt,name=x219", +} + +var E_X220 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 220, + Name: "testdata.x220", + Tag: "bytes,220,opt,name=x220", +} + +var E_X221 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 221, + Name: "testdata.x221", + Tag: "bytes,221,opt,name=x221", +} + +var E_X222 = 
&proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 222, + Name: "testdata.x222", + Tag: "bytes,222,opt,name=x222", +} + +var E_X223 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 223, + Name: "testdata.x223", + Tag: "bytes,223,opt,name=x223", +} + +var E_X224 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 224, + Name: "testdata.x224", + Tag: "bytes,224,opt,name=x224", +} + +var E_X225 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 225, + Name: "testdata.x225", + Tag: "bytes,225,opt,name=x225", +} + +var E_X226 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 226, + Name: "testdata.x226", + Tag: "bytes,226,opt,name=x226", +} + +var E_X227 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 227, + Name: "testdata.x227", + Tag: "bytes,227,opt,name=x227", +} + +var E_X228 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 228, + Name: "testdata.x228", + Tag: "bytes,228,opt,name=x228", +} + +var E_X229 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 229, + Name: "testdata.x229", + Tag: "bytes,229,opt,name=x229", +} + +var E_X230 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 230, + Name: "testdata.x230", + Tag: "bytes,230,opt,name=x230", +} + +var E_X231 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 231, + Name: "testdata.x231", + Tag: "bytes,231,opt,name=x231", +} + +var E_X232 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 232, + Name: "testdata.x232", + Tag: "bytes,232,opt,name=x232", +} + +var E_X233 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 233, + Name: "testdata.x233", + Tag: "bytes,233,opt,name=x233", +} + +var E_X234 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 234, + Name: "testdata.x234", + Tag: "bytes,234,opt,name=x234", +} + +var E_X235 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 235, + Name: "testdata.x235", + Tag: "bytes,235,opt,name=x235", +} + +var E_X236 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 236, + Name: "testdata.x236", + Tag: "bytes,236,opt,name=x236", +} + +var E_X237 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 237, + Name: "testdata.x237", + Tag: "bytes,237,opt,name=x237", +} + +var E_X238 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 238, + Name: "testdata.x238", + Tag: "bytes,238,opt,name=x238", +} + +var E_X239 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 239, + Name: "testdata.x239", + Tag: "bytes,239,opt,name=x239", +} + +var E_X240 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 240, + Name: "testdata.x240", + Tag: "bytes,240,opt,name=x240", +} + +var E_X241 = &proto.ExtensionDesc{ + ExtendedType: 
(*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 241, + Name: "testdata.x241", + Tag: "bytes,241,opt,name=x241", +} + +var E_X242 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 242, + Name: "testdata.x242", + Tag: "bytes,242,opt,name=x242", +} + +var E_X243 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 243, + Name: "testdata.x243", + Tag: "bytes,243,opt,name=x243", +} + +var E_X244 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 244, + Name: "testdata.x244", + Tag: "bytes,244,opt,name=x244", +} + +var E_X245 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 245, + Name: "testdata.x245", + Tag: "bytes,245,opt,name=x245", +} + +var E_X246 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 246, + Name: "testdata.x246", + Tag: "bytes,246,opt,name=x246", +} + +var E_X247 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 247, + Name: "testdata.x247", + Tag: "bytes,247,opt,name=x247", +} + +var E_X248 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 248, + Name: "testdata.x248", + Tag: "bytes,248,opt,name=x248", +} + +var E_X249 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 249, + Name: "testdata.x249", + Tag: "bytes,249,opt,name=x249", +} + +var E_X250 = &proto.ExtensionDesc{ + ExtendedType: (*MyMessageSet)(nil), + ExtensionType: (*Empty)(nil), + Field: 250, + Name: "testdata.x250", + Tag: "bytes,250,opt,name=x250", +} + +func init() { + proto.RegisterEnum("testdata.FOO", FOO_name, FOO_value) + proto.RegisterEnum("testdata.GoTest_KIND", GoTest_KIND_name, GoTest_KIND_value) + proto.RegisterEnum("testdata.MyMessage_Color", MyMessage_Color_name, MyMessage_Color_value) + proto.RegisterEnum("testdata.Defaults_Color", Defaults_Color_name, Defaults_Color_value) + proto.RegisterEnum("testdata.RepeatedEnum_Color", RepeatedEnum_Color_name, RepeatedEnum_Color_value) + proto.RegisterExtension(E_Ext_More) + proto.RegisterExtension(E_Ext_Text) + proto.RegisterExtension(E_Ext_Number) + proto.RegisterExtension(E_Greeting) + proto.RegisterExtension(E_X201) + proto.RegisterExtension(E_X202) + proto.RegisterExtension(E_X203) + proto.RegisterExtension(E_X204) + proto.RegisterExtension(E_X205) + proto.RegisterExtension(E_X206) + proto.RegisterExtension(E_X207) + proto.RegisterExtension(E_X208) + proto.RegisterExtension(E_X209) + proto.RegisterExtension(E_X210) + proto.RegisterExtension(E_X211) + proto.RegisterExtension(E_X212) + proto.RegisterExtension(E_X213) + proto.RegisterExtension(E_X214) + proto.RegisterExtension(E_X215) + proto.RegisterExtension(E_X216) + proto.RegisterExtension(E_X217) + proto.RegisterExtension(E_X218) + proto.RegisterExtension(E_X219) + proto.RegisterExtension(E_X220) + proto.RegisterExtension(E_X221) + proto.RegisterExtension(E_X222) + proto.RegisterExtension(E_X223) + proto.RegisterExtension(E_X224) + proto.RegisterExtension(E_X225) + proto.RegisterExtension(E_X226) + proto.RegisterExtension(E_X227) + proto.RegisterExtension(E_X228) + proto.RegisterExtension(E_X229) + proto.RegisterExtension(E_X230) + proto.RegisterExtension(E_X231) + proto.RegisterExtension(E_X232) + proto.RegisterExtension(E_X233) + proto.RegisterExtension(E_X234) + 
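+	// Each RegisterExtension call records its descriptor in the
+	// package-level extension map, which is what lets the text formatter
+	// print these extensions by name rather than as unknown fields.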
proto.RegisterExtension(E_X235) + proto.RegisterExtension(E_X236) + proto.RegisterExtension(E_X237) + proto.RegisterExtension(E_X238) + proto.RegisterExtension(E_X239) + proto.RegisterExtension(E_X240) + proto.RegisterExtension(E_X241) + proto.RegisterExtension(E_X242) + proto.RegisterExtension(E_X243) + proto.RegisterExtension(E_X244) + proto.RegisterExtension(E_X245) + proto.RegisterExtension(E_X246) + proto.RegisterExtension(E_X247) + proto.RegisterExtension(E_X248) + proto.RegisterExtension(E_X249) + proto.RegisterExtension(E_X250) +} diff --git a/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/testdata/test.proto b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/testdata/test.proto new file mode 100644 index 0000000000..9c0c4c8970 --- /dev/null +++ b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/testdata/test.proto @@ -0,0 +1,428 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2010 The Go Authors. All rights reserved. +// http://code.google.com/p/goprotobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// A feature-rich test file for the protocol compiler and libraries. + +syntax = "proto2"; + +package testdata; + +enum FOO { FOO1 = 1; }; + +message GoEnum { + required FOO foo = 1; +} + +message GoTestField { + required string Label = 1; + required string Type = 2; +} + +message GoTest { + // An enum, for completeness. + enum KIND { + VOID = 0; + + // Basic types + BOOL = 1; + BYTES = 2; + FINGERPRINT = 3; + FLOAT = 4; + INT = 5; + STRING = 6; + TIME = 7; + + // Groupings + TUPLE = 8; + ARRAY = 9; + MAP = 10; + + // Table types + TABLE = 11; + + // Functions + FUNCTION = 12; // last tag + }; + + // Some typical parameters + required KIND Kind = 1; + optional string Table = 2; + optional int32 Param = 3; + + // Required, repeated and optional foreign fields. 
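+  // In the generated Go code these become *GoTestField, []*GoTestField and
+  // *GoTestField; "required" is enforced only at marshal/unmarshal time,
+  // not by the Go type system.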
+ required GoTestField RequiredField = 4; + repeated GoTestField RepeatedField = 5; + optional GoTestField OptionalField = 6; + + // Required fields of all basic types + required bool F_Bool_required = 10; + required int32 F_Int32_required = 11; + required int64 F_Int64_required = 12; + required fixed32 F_Fixed32_required = 13; + required fixed64 F_Fixed64_required = 14; + required uint32 F_Uint32_required = 15; + required uint64 F_Uint64_required = 16; + required float F_Float_required = 17; + required double F_Double_required = 18; + required string F_String_required = 19; + required bytes F_Bytes_required = 101; + required sint32 F_Sint32_required = 102; + required sint64 F_Sint64_required = 103; + + // Repeated fields of all basic types + repeated bool F_Bool_repeated = 20; + repeated int32 F_Int32_repeated = 21; + repeated int64 F_Int64_repeated = 22; + repeated fixed32 F_Fixed32_repeated = 23; + repeated fixed64 F_Fixed64_repeated = 24; + repeated uint32 F_Uint32_repeated = 25; + repeated uint64 F_Uint64_repeated = 26; + repeated float F_Float_repeated = 27; + repeated double F_Double_repeated = 28; + repeated string F_String_repeated = 29; + repeated bytes F_Bytes_repeated = 201; + repeated sint32 F_Sint32_repeated = 202; + repeated sint64 F_Sint64_repeated = 203; + + // Optional fields of all basic types + optional bool F_Bool_optional = 30; + optional int32 F_Int32_optional = 31; + optional int64 F_Int64_optional = 32; + optional fixed32 F_Fixed32_optional = 33; + optional fixed64 F_Fixed64_optional = 34; + optional uint32 F_Uint32_optional = 35; + optional uint64 F_Uint64_optional = 36; + optional float F_Float_optional = 37; + optional double F_Double_optional = 38; + optional string F_String_optional = 39; + optional bytes F_Bytes_optional = 301; + optional sint32 F_Sint32_optional = 302; + optional sint64 F_Sint64_optional = 303; + + // Default-valued fields of all basic types + optional bool F_Bool_defaulted = 40 [default=true]; + optional int32 F_Int32_defaulted = 41 [default=32]; + optional int64 F_Int64_defaulted = 42 [default=64]; + optional fixed32 F_Fixed32_defaulted = 43 [default=320]; + optional fixed64 F_Fixed64_defaulted = 44 [default=640]; + optional uint32 F_Uint32_defaulted = 45 [default=3200]; + optional uint64 F_Uint64_defaulted = 46 [default=6400]; + optional float F_Float_defaulted = 47 [default=314159.]; + optional double F_Double_defaulted = 48 [default=271828.]; + optional string F_String_defaulted = 49 [default="hello, \"world!\"\n"]; + optional bytes F_Bytes_defaulted = 401 [default="Bignose"]; + optional sint32 F_Sint32_defaulted = 402 [default = -32]; + optional sint64 F_Sint64_defaulted = 403 [default = -64]; + + // Packed repeated fields (no string or bytes). + repeated bool F_Bool_repeated_packed = 50 [packed=true]; + repeated int32 F_Int32_repeated_packed = 51 [packed=true]; + repeated int64 F_Int64_repeated_packed = 52 [packed=true]; + repeated fixed32 F_Fixed32_repeated_packed = 53 [packed=true]; + repeated fixed64 F_Fixed64_repeated_packed = 54 [packed=true]; + repeated uint32 F_Uint32_repeated_packed = 55 [packed=true]; + repeated uint64 F_Uint64_repeated_packed = 56 [packed=true]; + repeated float F_Float_repeated_packed = 57 [packed=true]; + repeated double F_Double_repeated_packed = 58 [packed=true]; + repeated sint32 F_Sint32_repeated_packed = 502 [packed=true]; + repeated sint64 F_Sint64_repeated_packed = 503 [packed=true]; + + // Required, repeated, and optional groups. 
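+  // Groups are a legacy proto2 construct: on the wire they are delimited
+  // by start-group/end-group tags rather than a length prefix, and each
+  // one is generated as a nested Go struct (e.g. GoTest_RequiredGroup).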
+ required group RequiredGroup = 70 { + required string RequiredField = 71; + }; + + repeated group RepeatedGroup = 80 { + required string RequiredField = 81; + }; + + optional group OptionalGroup = 90 { + required string RequiredField = 91; + }; +} + +// For testing skipping of unrecognized fields. +// Numbers are all big, larger than tag numbers in GoTestField, +// the message used in the corresponding test. +message GoSkipTest { + required int32 skip_int32 = 11; + required fixed32 skip_fixed32 = 12; + required fixed64 skip_fixed64 = 13; + required string skip_string = 14; + required group SkipGroup = 15 { + required int32 group_int32 = 16; + required string group_string = 17; + } +} + +// For testing packed/non-packed decoder switching. +// A serialized instance of one should be deserializable as the other. +message NonPackedTest { + repeated int32 a = 1; +} + +message PackedTest { + repeated int32 b = 1 [packed=true]; +} + +message MaxTag { + // Maximum possible tag number. + optional string last_field = 536870911; +} + +message OldMessage { + message Nested { + optional string name = 1; + } + optional Nested nested = 1; + + optional int32 num = 2; +} + +// NewMessage is wire compatible with OldMessage; +// imagine it as a future version. +message NewMessage { + message Nested { + optional string name = 1; + optional string food_group = 2; + } + optional Nested nested = 1; + + // This is an int32 in OldMessage. + optional int64 num = 2; +} + +// Smaller tests for ASCII formatting. + +message InnerMessage { + required string host = 1; + optional int32 port = 2 [default=4000]; + optional bool connected = 3; +} + +message OtherMessage { + optional int64 key = 1; + optional bytes value = 2; + optional float weight = 3; + optional InnerMessage inner = 4; +} + +message MyMessage { + required int32 count = 1; + optional string name = 2; + optional string quote = 3; + repeated string pet = 4; + optional InnerMessage inner = 5; + repeated OtherMessage others = 6; + repeated InnerMessage rep_inner = 12; + + enum Color { + RED = 0; + GREEN = 1; + BLUE = 2; + }; + optional Color bikeshed = 7; + + optional group SomeGroup = 8 { + optional int32 group_field = 9; + } + + // This field becomes [][]byte in the generated code. 
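+  // (packed encoding is not an option here: it applies only to repeated
+  // fields of primitive numeric types, never to bytes or string)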
+ repeated bytes rep_bytes = 10; + + optional double bigfloat = 11; + + extensions 100 to max; +} + +message Ext { + extend MyMessage { + optional Ext more = 103; + optional string text = 104; + optional int32 number = 105; + } + + optional string data = 1; +} + +extend MyMessage { + repeated string greeting = 106; +} + +message MyMessageSet { + option message_set_wire_format = true; + extensions 100 to max; +} + +message Empty { +} + +extend MyMessageSet { + optional Empty x201 = 201; + optional Empty x202 = 202; + optional Empty x203 = 203; + optional Empty x204 = 204; + optional Empty x205 = 205; + optional Empty x206 = 206; + optional Empty x207 = 207; + optional Empty x208 = 208; + optional Empty x209 = 209; + optional Empty x210 = 210; + optional Empty x211 = 211; + optional Empty x212 = 212; + optional Empty x213 = 213; + optional Empty x214 = 214; + optional Empty x215 = 215; + optional Empty x216 = 216; + optional Empty x217 = 217; + optional Empty x218 = 218; + optional Empty x219 = 219; + optional Empty x220 = 220; + optional Empty x221 = 221; + optional Empty x222 = 222; + optional Empty x223 = 223; + optional Empty x224 = 224; + optional Empty x225 = 225; + optional Empty x226 = 226; + optional Empty x227 = 227; + optional Empty x228 = 228; + optional Empty x229 = 229; + optional Empty x230 = 230; + optional Empty x231 = 231; + optional Empty x232 = 232; + optional Empty x233 = 233; + optional Empty x234 = 234; + optional Empty x235 = 235; + optional Empty x236 = 236; + optional Empty x237 = 237; + optional Empty x238 = 238; + optional Empty x239 = 239; + optional Empty x240 = 240; + optional Empty x241 = 241; + optional Empty x242 = 242; + optional Empty x243 = 243; + optional Empty x244 = 244; + optional Empty x245 = 245; + optional Empty x246 = 246; + optional Empty x247 = 247; + optional Empty x248 = 248; + optional Empty x249 = 249; + optional Empty x250 = 250; +} + +message MessageList { + repeated group Message = 1 { + required string name = 2; + required int32 count = 3; + } +} + +message Strings { + optional string string_field = 1; + optional bytes bytes_field = 2; +} + +message Defaults { + enum Color { + RED = 0; + GREEN = 1; + BLUE = 2; + } + + // Default-valued fields of all basic types. + // Same as GoTest, but copied here to make testing easier. + optional bool F_Bool = 1 [default=true]; + optional int32 F_Int32 = 2 [default=32]; + optional int64 F_Int64 = 3 [default=64]; + optional fixed32 F_Fixed32 = 4 [default=320]; + optional fixed64 F_Fixed64 = 5 [default=640]; + optional uint32 F_Uint32 = 6 [default=3200]; + optional uint64 F_Uint64 = 7 [default=6400]; + optional float F_Float = 8 [default=314159.]; + optional double F_Double = 9 [default=271828.]; + optional string F_String = 10 [default="hello, \"world!\"\n"]; + optional bytes F_Bytes = 11 [default="Bignose"]; + optional sint32 F_Sint32 = 12 [default=-32]; + optional sint64 F_Sint64 = 13 [default=-64]; + optional Color F_Enum = 14 [default=GREEN]; + + // More fields with crazy defaults. + optional float F_Pinf = 15 [default=inf]; + optional float F_Ninf = 16 [default=-inf]; + optional float F_Nan = 17 [default=nan]; + + // Sub-message. + optional SubDefaults sub = 18; + + // Redundant but explicit defaults. 
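+  // An explicit empty default is indistinguishable from no default in the
+  // generated getter, but it exercises the def= tag-parsing path.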
+ optional string str_zero = 19 [default=""]; +} + +message SubDefaults { + optional int64 n = 1 [default=7]; +} + +message RepeatedEnum { + enum Color { + RED = 1; + } + repeated Color color = 1; +} + +message MoreRepeated { + repeated bool bools = 1; + repeated bool bools_packed = 2 [packed=true]; + repeated int32 ints = 3; + repeated int32 ints_packed = 4 [packed=true]; + repeated int64 int64s_packed = 7 [packed=true]; + repeated string strings = 5; + repeated fixed32 fixeds = 6; +} + +// GroupOld and GroupNew have the same wire format. +// GroupNew has a new field inside a group. + +message GroupOld { + optional group G = 101 { + optional int32 x = 2; + } +} + +message GroupNew { + optional group G = 101 { + optional int32 x = 2; + optional int32 y = 3; + } +} + +message FloatingPoint { + required double f = 1; +} diff --git a/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/text.go b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/text.go new file mode 100644 index 0000000000..5063009c27 --- /dev/null +++ b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/text.go @@ -0,0 +1,701 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2010 The Go Authors. All rights reserved. +// http://code.google.com/p/goprotobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto + +// Functions for writing the text protocol buffer format. + +import ( + "bufio" + "bytes" + "fmt" + "io" + "log" + "math" + "os" + "reflect" + "sort" + "strings" +) + +var ( + newline = []byte("\n") + spaces = []byte(" ") + gtNewline = []byte(">\n") + endBraceNewline = []byte("}\n") + backslashN = []byte{'\\', 'n'} + backslashR = []byte{'\\', 'r'} + backslashT = []byte{'\\', 't'} + backslashDQ = []byte{'\\', '"'} + backslashBS = []byte{'\\', '\\'} + posInf = []byte("inf") + negInf = []byte("-inf") + nan = []byte("nan") +) + +type writer interface { + io.Writer + WriteByte(byte) error +} + +// textWriter is an io.Writer that tracks its indentation level. 
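+// Indentation is two spaces per level and is emitted lazily by writeIndent
+// at the start of each new line, so callers never track column state.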
+type textWriter struct {
+	ind      int
+	complete bool // if the current position is a complete line
+	compact  bool // whether to write out as a one-liner
+	w        writer
+}
+
+// textMarshaler is implemented by Messages that can marshal themselves.
+// It is identical to encoding.TextMarshaler, introduced in Go 1.2,
+// which will eventually replace it.
+type textMarshaler interface {
+	MarshalText() (text []byte, err error)
+}
+
+func (w *textWriter) WriteString(s string) (n int, err error) {
+	if !strings.Contains(s, "\n") {
+		if !w.compact && w.complete {
+			w.writeIndent()
+		}
+		w.complete = false
+		return io.WriteString(w.w, s)
+	}
+	// WriteString is typically called without newlines, so this
+	// codepath and its copy are rare. We copy to avoid
+	// duplicating all of Write's logic here.
+	return w.Write([]byte(s))
+}
+
+func (w *textWriter) Write(p []byte) (n int, err error) {
+	newlines := bytes.Count(p, newline)
+	if newlines == 0 {
+		if !w.compact && w.complete {
+			w.writeIndent()
+		}
+		n, err = w.w.Write(p)
+		w.complete = false
+		return n, err
+	}
+
+	frags := bytes.SplitN(p, newline, newlines+1)
+	if w.compact {
+		for i, frag := range frags {
+			if i > 0 {
+				if err := w.w.WriteByte(' '); err != nil {
+					return n, err
+				}
+				n++
+			}
+			nn, err := w.w.Write(frag)
+			n += nn
+			if err != nil {
+				return n, err
+			}
+		}
+		return n, nil
+	}
+
+	for i, frag := range frags {
+		if w.complete {
+			w.writeIndent()
+		}
+		nn, err := w.w.Write(frag)
+		n += nn
+		if err != nil {
+			return n, err
+		}
+		if i+1 < len(frags) {
+			if err := w.w.WriteByte('\n'); err != nil {
+				return n, err
+			}
+			n++
+		}
+	}
+	w.complete = len(frags[len(frags)-1]) == 0
+	return n, nil
+}
+
+func (w *textWriter) WriteByte(c byte) error {
+	if w.compact && c == '\n' {
+		c = ' '
+	}
+	if !w.compact && w.complete {
+		w.writeIndent()
+	}
+	err := w.w.WriteByte(c)
+	w.complete = c == '\n'
+	return err
+}
+
+func (w *textWriter) indent() { w.ind++ }
+
+func (w *textWriter) unindent() {
+	if w.ind == 0 {
+		log.Printf("proto: textWriter unindented too far")
+		return
+	}
+	w.ind--
+}
+
+func writeName(w *textWriter, props *Properties) error {
+	if _, err := w.WriteString(props.OrigName); err != nil {
+		return err
+	}
+	if props.Wire != "group" {
+		return w.WriteByte(':')
+	}
+	return nil
+}
+
+var (
+	messageSetType = reflect.TypeOf((*MessageSet)(nil)).Elem()
+)
+
+// raw is the interface satisfied by RawMessage.
+type raw interface {
+	Bytes() []byte
+}
+
+func writeStruct(w *textWriter, sv reflect.Value) error {
+	if sv.Type() == messageSetType {
+		return writeMessageSet(w, sv.Addr().Interface().(*MessageSet))
+	}
+
+	st := sv.Type()
+	sprops := GetProperties(st)
+	for i := 0; i < sv.NumField(); i++ {
+		fv := sv.Field(i)
+		props := sprops.Prop[i]
+		name := st.Field(i).Name
+
+		if strings.HasPrefix(name, "XXX_") {
+			// There are two XXX_ fields:
+			//   XXX_unrecognized []byte
+			//   XXX_extensions   map[int32]proto.Extension
+			// The first is handled here;
+			// the second is handled at the bottom of this function.
+			if name == "XXX_unrecognized" && !fv.IsNil() {
+				if err := writeUnknownStruct(w, fv.Interface().([]byte)); err != nil {
+					return err
+				}
+			}
+			continue
+		}
+		if fv.Kind() == reflect.Ptr && fv.IsNil() {
+			// Field not filled in. This could be an optional field or
+			// a required field that wasn't filled in. Either way, there
+			// isn't anything we can show for it.
+			continue
+		}
+		if fv.Kind() == reflect.Slice && fv.IsNil() {
+			// Repeated field that is empty, or a bytes field that is unused.
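+			// (the text format cannot distinguish an absent repeated field
+			// from an empty one, so neither produces any output)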
+ continue + } + + if props.Repeated && fv.Kind() == reflect.Slice { + // Repeated field. + for j := 0; j < fv.Len(); j++ { + if err := writeName(w, props); err != nil { + return err + } + if !w.compact { + if err := w.WriteByte(' '); err != nil { + return err + } + } + v := fv.Index(j) + if v.Kind() == reflect.Ptr && v.IsNil() { + // A nil message in a repeated field is not valid, + // but we can handle that more gracefully than panicking. + if _, err := w.Write([]byte("\n")); err != nil { + return err + } + continue + } + if err := writeAny(w, v, props); err != nil { + return err + } + if err := w.WriteByte('\n'); err != nil { + return err + } + } + continue + } + + if err := writeName(w, props); err != nil { + return err + } + if !w.compact { + if err := w.WriteByte(' '); err != nil { + return err + } + } + if b, ok := fv.Interface().(raw); ok { + if err := writeRaw(w, b.Bytes()); err != nil { + return err + } + continue + } + + // Enums have a String method, so writeAny will work fine. + if err := writeAny(w, fv, props); err != nil { + return err + } + + if err := w.WriteByte('\n'); err != nil { + return err + } + } + + // Extensions (the XXX_extensions field). + pv := sv.Addr() + if pv.Type().Implements(extendableProtoType) { + if err := writeExtensions(w, pv); err != nil { + return err + } + } + + return nil +} + +// writeRaw writes an uninterpreted raw message. +func writeRaw(w *textWriter, b []byte) error { + if err := w.WriteByte('<'); err != nil { + return err + } + if !w.compact { + if err := w.WriteByte('\n'); err != nil { + return err + } + } + w.indent() + if err := writeUnknownStruct(w, b); err != nil { + return err + } + w.unindent() + if err := w.WriteByte('>'); err != nil { + return err + } + return nil +} + +// writeAny writes an arbitrary field. +func writeAny(w *textWriter, v reflect.Value, props *Properties) error { + v = reflect.Indirect(v) + + // Floats have special cases. + if v.Kind() == reflect.Float32 || v.Kind() == reflect.Float64 { + x := v.Float() + var b []byte + switch { + case math.IsInf(x, 1): + b = posInf + case math.IsInf(x, -1): + b = negInf + case math.IsNaN(x): + b = nan + } + if b != nil { + _, err := w.Write(b) + return err + } + // Other values are handled below. + } + + // We don't attempt to serialise every possible value type; only those + // that can occur in protocol buffers. + switch v.Kind() { + case reflect.Slice: + // Should only be a []byte; repeated fields are handled in writeStruct. + if err := writeString(w, string(v.Interface().([]byte))); err != nil { + return err + } + case reflect.String: + if err := writeString(w, v.String()); err != nil { + return err + } + case reflect.Struct: + // Required/optional group/message. + var bra, ket byte = '<', '>' + if props != nil && props.Wire == "group" { + bra, ket = '{', '}' + } + if err := w.WriteByte(bra); err != nil { + return err + } + if !w.compact { + if err := w.WriteByte('\n'); err != nil { + return err + } + } + w.indent() + if tm, ok := v.Interface().(textMarshaler); ok { + text, err := tm.MarshalText() + if err != nil { + return err + } + if _, err = w.Write(text); err != nil { + return err + } + } else if err := writeStruct(w, v); err != nil { + return err + } + w.unindent() + if err := w.WriteByte(ket); err != nil { + return err + } + default: + _, err := fmt.Fprint(w, v.Interface()) + return err + } + return nil +} + +// equivalent to C's isprint. +func isprint(c byte) bool { + return c >= 0x20 && c < 0x7f +} + +// writeString writes a string in the protocol buffer text format. 
+// It is similar to strconv.Quote except we don't use Go escape sequences, +// we treat the string as a byte sequence, and we use octal escapes. +// These differences are to maintain interoperability with the other +// languages' implementations of the text format. +func writeString(w *textWriter, s string) error { + // use WriteByte here to get any needed indent + if err := w.WriteByte('"'); err != nil { + return err + } + // Loop over the bytes, not the runes. + for i := 0; i < len(s); i++ { + var err error + // Divergence from C++: we don't escape apostrophes. + // There's no need to escape them, and the C++ parser + // copes with a naked apostrophe. + switch c := s[i]; c { + case '\n': + _, err = w.w.Write(backslashN) + case '\r': + _, err = w.w.Write(backslashR) + case '\t': + _, err = w.w.Write(backslashT) + case '"': + _, err = w.w.Write(backslashDQ) + case '\\': + _, err = w.w.Write(backslashBS) + default: + if isprint(c) { + err = w.w.WriteByte(c) + } else { + _, err = fmt.Fprintf(w.w, "\\%03o", c) + } + } + if err != nil { + return err + } + } + return w.WriteByte('"') +} + +func writeMessageSet(w *textWriter, ms *MessageSet) error { + for _, item := range ms.Item { + id := *item.TypeId + if msd, ok := messageSetMap[id]; ok { + // Known message set type. + if _, err := fmt.Fprintf(w, "[%s]: <\n", msd.name); err != nil { + return err + } + w.indent() + + pb := reflect.New(msd.t.Elem()) + if err := Unmarshal(item.Message, pb.Interface().(Message)); err != nil { + if _, err := fmt.Fprintf(w, "/* bad message: %v */\n", err); err != nil { + return err + } + } else { + if err := writeStruct(w, pb.Elem()); err != nil { + return err + } + } + } else { + // Unknown type. + if _, err := fmt.Fprintf(w, "[%d]: <\n", id); err != nil { + return err + } + w.indent() + if err := writeUnknownStruct(w, item.Message); err != nil { + return err + } + } + w.unindent() + if _, err := w.Write(gtNewline); err != nil { + return err + } + } + return nil +} + +func writeUnknownStruct(w *textWriter, data []byte) (err error) { + if !w.compact { + if _, err := fmt.Fprintf(w, "/* %d unknown bytes */\n", len(data)); err != nil { + return err + } + } + b := NewBuffer(data) + for b.index < len(b.buf) { + x, err := b.DecodeVarint() + if err != nil { + _, err := fmt.Fprintf(w, "/* %v */\n", err) + return err + } + wire, tag := x&7, x>>3 + if wire == WireEndGroup { + w.unindent() + if _, err := w.Write(endBraceNewline); err != nil { + return err + } + continue + } + if _, err := fmt.Fprint(w, tag); err != nil { + return err + } + if wire != WireStartGroup { + if err := w.WriteByte(':'); err != nil { + return err + } + } + if !w.compact || wire == WireStartGroup { + if err := w.WriteByte(' '); err != nil { + return err + } + } + switch wire { + case WireBytes: + buf, e := b.DecodeRawBytes(false) + if e == nil { + _, err = fmt.Fprintf(w, "%q", buf) + } else { + _, err = fmt.Fprintf(w, "/* %v */", e) + } + case WireFixed32: + x, err = b.DecodeFixed32() + err = writeUnknownInt(w, x, err) + case WireFixed64: + x, err = b.DecodeFixed64() + err = writeUnknownInt(w, x, err) + case WireStartGroup: + err = w.WriteByte('{') + w.indent() + case WireVarint: + x, err = b.DecodeVarint() + err = writeUnknownInt(w, x, err) + default: + _, err = fmt.Fprintf(w, "/* unknown wire type %d */", wire) + } + if err != nil { + return err + } + if err = w.WriteByte('\n'); err != nil { + return err + } + } + return nil +} + +func writeUnknownInt(w *textWriter, x uint64, err error) error { + if err == nil { + _, err = fmt.Fprint(w, x) + } else 
{ + _, err = fmt.Fprintf(w, "/* %v */", err) + } + return err +} + +type int32Slice []int32 + +func (s int32Slice) Len() int { return len(s) } +func (s int32Slice) Less(i, j int) bool { return s[i] < s[j] } +func (s int32Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +// writeExtensions writes all the extensions in pv. +// pv is assumed to be a pointer to a protocol message struct that is extendable. +func writeExtensions(w *textWriter, pv reflect.Value) error { + emap := extensionMaps[pv.Type().Elem()] + ep := pv.Interface().(extendableProto) + + // Order the extensions by ID. + // This isn't strictly necessary, but it will give us + // canonical output, which will also make testing easier. + m := ep.ExtensionMap() + ids := make([]int32, 0, len(m)) + for id := range m { + ids = append(ids, id) + } + sort.Sort(int32Slice(ids)) + + for _, extNum := range ids { + ext := m[extNum] + var desc *ExtensionDesc + if emap != nil { + desc = emap[extNum] + } + if desc == nil { + // Unknown extension. + if err := writeUnknownStruct(w, ext.enc); err != nil { + return err + } + continue + } + + pb, err := GetExtension(ep, desc) + if err != nil { + if _, err := fmt.Fprintln(os.Stderr, "proto: failed getting extension: ", err); err != nil { + return err + } + continue + } + + // Repeated extensions will appear as a slice. + if !desc.repeated() { + if err := writeExtension(w, desc.Name, pb); err != nil { + return err + } + } else { + v := reflect.ValueOf(pb) + for i := 0; i < v.Len(); i++ { + if err := writeExtension(w, desc.Name, v.Index(i).Interface()); err != nil { + return err + } + } + } + } + return nil +} + +func writeExtension(w *textWriter, name string, pb interface{}) error { + if _, err := fmt.Fprintf(w, "[%s]:", name); err != nil { + return err + } + if !w.compact { + if err := w.WriteByte(' '); err != nil { + return err + } + } + if err := writeAny(w, reflect.ValueOf(pb), nil); err != nil { + return err + } + if err := w.WriteByte('\n'); err != nil { + return err + } + return nil +} + +func (w *textWriter) writeIndent() { + if !w.complete { + return + } + remain := w.ind * 2 + for remain > 0 { + n := remain + if n > len(spaces) { + n = len(spaces) + } + w.w.Write(spaces[:n]) + remain -= n + } + w.complete = false +} + +func marshalText(w io.Writer, pb Message, compact bool) error { + val := reflect.ValueOf(pb) + if pb == nil || val.IsNil() { + w.Write([]byte("")) + return nil + } + var bw *bufio.Writer + ww, ok := w.(writer) + if !ok { + bw = bufio.NewWriter(w) + ww = bw + } + aw := &textWriter{ + w: ww, + complete: true, + compact: compact, + } + + if tm, ok := pb.(textMarshaler); ok { + text, err := tm.MarshalText() + if err != nil { + return err + } + if _, err = aw.Write(text); err != nil { + return err + } + if bw != nil { + return bw.Flush() + } + return nil + } + // Dereference the received pointer so we don't have outer < and >. + v := reflect.Indirect(val) + if err := writeStruct(aw, v); err != nil { + return err + } + if bw != nil { + return bw.Flush() + } + return nil +} + +// MarshalText writes a given protocol buffer in text format. +// The only errors returned are from w. +func MarshalText(w io.Writer, pb Message) error { + return marshalText(w, pb, false) +} + +// MarshalTextString is the same as MarshalText, but returns the string directly. +func MarshalTextString(pb Message) string { + var buf bytes.Buffer + marshalText(&buf, pb, false) + return buf.String() +} + +// CompactText writes a given protocol buffer in compact text format (one line). 
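+//
+// Illustrative example (editor's sketch, not upstream documentation; assumes
+// a hypothetical generated message type *pb.MyMessage with an optional
+// int32 field "count"):
+//
+//	m := &pb.MyMessage{Count: proto.Int32(42)}
+//	proto.MarshalText(os.Stdout, m)  // writes "count: 42\n"
+//	proto.CompactText(os.Stdout, m)  // writes "count:42 " (one line, trailing space)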
+func CompactText(w io.Writer, pb Message) error { return marshalText(w, pb, true) }
+
+// CompactTextString is the same as CompactText, but returns the string directly.
+func CompactTextString(pb Message) string {
+	var buf bytes.Buffer
+	marshalText(&buf, pb, true)
+	return buf.String()
+}
diff --git a/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/text_parser.go b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/text_parser.go
new file mode 100644
index 0000000000..d235ca9bc8
--- /dev/null
+++ b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/text_parser.go
@@ -0,0 +1,684 @@
+// Go support for Protocol Buffers - Google's data interchange format
+//
+// Copyright 2010 The Go Authors. All rights reserved.
+// http://code.google.com/p/goprotobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+package proto
+
+// Functions for parsing the Text protocol buffer format.
+// TODO: message sets.
+
+import (
+	"errors"
+	"fmt"
+	"reflect"
+	"strconv"
+	"strings"
+	"unicode/utf8"
+)
+
+// textUnmarshaler is implemented by Messages that can unmarshal themselves.
+// It is identical to encoding.TextUnmarshaler, introduced in Go 1.2,
+// which will eventually replace it.
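+// A message type opts in by defining the method itself; a minimal sketch
+// (CustomMsg is a hypothetical example type, not part of this package):
+//
+//	func (m *CustomMsg) UnmarshalText(text []byte) error {
+//		// populate m's fields from text; return an error on bad input
+//		return nil
+//	}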
+type textUnmarshaler interface { + UnmarshalText(text []byte) error +} + +type ParseError struct { + Message string + Line int // 1-based line number + Offset int // 0-based byte offset from start of input +} + +func (p *ParseError) Error() string { + if p.Line == 1 { + // show offset only for first line + return fmt.Sprintf("line 1.%d: %v", p.Offset, p.Message) + } + return fmt.Sprintf("line %d: %v", p.Line, p.Message) +} + +type token struct { + value string + err *ParseError + line int // line number + offset int // byte number from start of input, not start of line + unquoted string // the unquoted version of value, if it was a quoted string +} + +func (t *token) String() string { + if t.err == nil { + return fmt.Sprintf("%q (line=%d, offset=%d)", t.value, t.line, t.offset) + } + return fmt.Sprintf("parse error: %v", t.err) +} + +type textParser struct { + s string // remaining input + done bool // whether the parsing is finished (success or error) + backed bool // whether back() was called + offset, line int + cur token +} + +func newTextParser(s string) *textParser { + p := new(textParser) + p.s = s + p.line = 1 + p.cur.line = 1 + return p +} + +func (p *textParser) errorf(format string, a ...interface{}) *ParseError { + pe := &ParseError{fmt.Sprintf(format, a...), p.cur.line, p.cur.offset} + p.cur.err = pe + p.done = true + return pe +} + +// Numbers and identifiers are matched by [-+._A-Za-z0-9] +func isIdentOrNumberChar(c byte) bool { + switch { + case 'A' <= c && c <= 'Z', 'a' <= c && c <= 'z': + return true + case '0' <= c && c <= '9': + return true + } + switch c { + case '-', '+', '.', '_': + return true + } + return false +} + +func isWhitespace(c byte) bool { + switch c { + case ' ', '\t', '\n', '\r': + return true + } + return false +} + +func (p *textParser) skipWhitespace() { + i := 0 + for i < len(p.s) && (isWhitespace(p.s[i]) || p.s[i] == '#') { + if p.s[i] == '#' { + // comment; skip to end of line or input + for i < len(p.s) && p.s[i] != '\n' { + i++ + } + if i == len(p.s) { + break + } + } + if p.s[i] == '\n' { + p.line++ + } + i++ + } + p.offset += i + p.s = p.s[i:len(p.s)] + if len(p.s) == 0 { + p.done = true + } +} + +func (p *textParser) advance() { + // Skip whitespace + p.skipWhitespace() + if p.done { + return + } + + // Start of non-whitespace + p.cur.err = nil + p.cur.offset, p.cur.line = p.offset, p.line + p.cur.unquoted = "" + switch p.s[0] { + case '<', '>', '{', '}', ':', '[', ']', ';', ',': + // Single symbol + p.cur.value, p.s = p.s[0:1], p.s[1:len(p.s)] + case '"', '\'': + // Quoted string + i := 1 + for i < len(p.s) && p.s[i] != p.s[0] && p.s[i] != '\n' { + if p.s[i] == '\\' && i+1 < len(p.s) { + // skip escaped char + i++ + } + i++ + } + if i >= len(p.s) || p.s[i] != p.s[0] { + p.errorf("unmatched quote") + return + } + unq, err := unquoteC(p.s[1:i], rune(p.s[0])) + if err != nil { + p.errorf("invalid quoted string %v", p.s[0:i+1]) + return + } + p.cur.value, p.s = p.s[0:i+1], p.s[i+1:len(p.s)] + p.cur.unquoted = unq + default: + i := 0 + for i < len(p.s) && isIdentOrNumberChar(p.s[i]) { + i++ + } + if i == 0 { + p.errorf("unexpected byte %#x", p.s[0]) + return + } + p.cur.value, p.s = p.s[0:i], p.s[i:len(p.s)] + } + p.offset += len(p.cur.value) +} + +var ( + errBadUTF8 = errors.New("proto: bad UTF-8") + errBadHex = errors.New("proto: bad hexadecimal") +) + +func unquoteC(s string, quote rune) (string, error) { + // This is based on C++'s tokenizer.cc. + // Despite its name, this is *not* parsing C syntax. 
+ // For instance, "\0" is an invalid quoted string. + + // Avoid allocation in trivial cases. + simple := true + for _, r := range s { + if r == '\\' || r == quote { + simple = false + break + } + } + if simple { + return s, nil + } + + buf := make([]byte, 0, 3*len(s)/2) + for len(s) > 0 { + r, n := utf8.DecodeRuneInString(s) + if r == utf8.RuneError && n == 1 { + return "", errBadUTF8 + } + s = s[n:] + if r != '\\' { + if r < utf8.RuneSelf { + buf = append(buf, byte(r)) + } else { + buf = append(buf, string(r)...) + } + continue + } + + ch, tail, err := unescape(s) + if err != nil { + return "", err + } + buf = append(buf, ch...) + s = tail + } + return string(buf), nil +} + +func unescape(s string) (ch string, tail string, err error) { + r, n := utf8.DecodeRuneInString(s) + if r == utf8.RuneError && n == 1 { + return "", "", errBadUTF8 + } + s = s[n:] + switch r { + case 'a': + return "\a", s, nil + case 'b': + return "\b", s, nil + case 'f': + return "\f", s, nil + case 'n': + return "\n", s, nil + case 'r': + return "\r", s, nil + case 't': + return "\t", s, nil + case 'v': + return "\v", s, nil + case '?': + return "?", s, nil // trigraph workaround + case '\'', '"', '\\': + return string(r), s, nil + case '0', '1', '2', '3', '4', '5', '6', '7', 'x', 'X': + if len(s) < 2 { + return "", "", fmt.Errorf(`\%c requires 2 following digits`, r) + } + base := 8 + ss := s[:2] + s = s[2:] + if r == 'x' || r == 'X' { + base = 16 + } else { + ss = string(r) + ss + } + i, err := strconv.ParseUint(ss, base, 8) + if err != nil { + return "", "", err + } + return string([]byte{byte(i)}), s, nil + case 'u', 'U': + n := 4 + if r == 'U' { + n = 8 + } + if len(s) < n { + return "", "", fmt.Errorf(`\%c requires %d digits`, r, n) + } + + bs := make([]byte, n/2) + for i := 0; i < n; i += 2 { + a, ok1 := unhex(s[i]) + b, ok2 := unhex(s[i+1]) + if !ok1 || !ok2 { + return "", "", errBadHex + } + bs[i/2] = a<<4 | b + } + s = s[n:] + return string(bs), s, nil + } + return "", "", fmt.Errorf(`unknown escape \%c`, r) +} + +// Adapted from src/pkg/strconv/quote.go. +func unhex(b byte) (v byte, ok bool) { + switch { + case '0' <= b && b <= '9': + return b - '0', true + case 'a' <= b && b <= 'f': + return b - 'a' + 10, true + case 'A' <= b && b <= 'F': + return b - 'A' + 10, true + } + return 0, false +} + +// Back off the parser by one token. Can only be done between calls to next(). +// It makes the next advance() a no-op. +func (p *textParser) back() { p.backed = true } + +// Advances the parser and returns the new current token. +func (p *textParser) next() *token { + if p.backed || p.done { + p.backed = false + return &p.cur + } + p.advance() + if p.done { + p.cur.value = "" + } else if len(p.cur.value) > 0 && p.cur.value[0] == '"' { + // Look for multiple quoted strings separated by whitespace, + // and concatenate them. + cat := p.cur + for { + p.skipWhitespace() + if p.done || p.s[0] != '"' { + break + } + p.advance() + if p.cur.err != nil { + return &p.cur + } + cat.value += " " + p.cur.value + cat.unquoted += p.cur.unquoted + } + p.done = false // parser may have seen EOF, but we want to return cat + p.cur = cat + } + return &p.cur +} + +// Return an error indicating which required field was not set. 
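+// The message follows the usual ParseError format, e.g.:
+//
+//	line 1.0: message testdata.MyMessage missing required field "count"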
+func (p *textParser) missingRequiredFieldError(sv reflect.Value) *ParseError { + st := sv.Type() + sprops := GetProperties(st) + for i := 0; i < st.NumField(); i++ { + if !isNil(sv.Field(i)) { + continue + } + + props := sprops.Prop[i] + if props.Required { + return p.errorf("message %v missing required field %q", st, props.OrigName) + } + } + return p.errorf("message %v missing required field", st) // should not happen +} + +// Returns the index in the struct for the named field, as well as the parsed tag properties. +func structFieldByName(st reflect.Type, name string) (int, *Properties, bool) { + sprops := GetProperties(st) + i, ok := sprops.decoderOrigNames[name] + if ok { + return i, sprops.Prop[i], true + } + return -1, nil, false +} + +// Consume a ':' from the input stream (if the next token is a colon), +// returning an error if a colon is needed but not present. +func (p *textParser) checkForColon(props *Properties, typ reflect.Type) *ParseError { + tok := p.next() + if tok.err != nil { + return tok.err + } + if tok.value != ":" { + // Colon is optional when the field is a group or message. + needColon := true + switch props.Wire { + case "group": + needColon = false + case "bytes": + // A "bytes" field is either a message, a string, or a repeated field; + // those three become *T, *string and []T respectively, so we can check for + // this field being a pointer to a non-string. + if typ.Kind() == reflect.Ptr { + // *T or *string + if typ.Elem().Kind() == reflect.String { + break + } + } else if typ.Kind() == reflect.Slice { + // []T or []*T + if typ.Elem().Kind() != reflect.Ptr { + break + } + } + needColon = false + } + if needColon { + return p.errorf("expected ':', found %q", tok.value) + } + p.back() + } + return nil +} + +func (p *textParser) readStruct(sv reflect.Value, terminator string) *ParseError { + st := sv.Type() + reqCount := GetProperties(st).reqCount + // A struct is a sequence of "name: value", terminated by one of + // '>' or '}', or the end of the input. A name may also be + // "[extension]". + for { + tok := p.next() + if tok.err != nil { + return tok.err + } + if tok.value == terminator { + break + } + if tok.value == "[" { + // Looks like an extension. + // + // TODO: Check whether we need to handle + // namespace rooted names (e.g. ".something.Foo"). + tok = p.next() + if tok.err != nil { + return tok.err + } + var desc *ExtensionDesc + // This could be faster, but it's functional. + // TODO: Do something smarter than a linear scan. + for _, d := range RegisteredExtensions(reflect.New(st).Interface().(Message)) { + if d.Name == tok.value { + desc = d + break + } + } + if desc == nil { + return p.errorf("unrecognized extension %q", tok.value) + } + // Check the extension terminator. + tok = p.next() + if tok.err != nil { + return tok.err + } + if tok.value != "]" { + return p.errorf("unrecognized extension terminator %q", tok.value) + } + + props := &Properties{} + props.Parse(desc.Tag) + + typ := reflect.TypeOf(desc.ExtensionType) + if err := p.checkForColon(props, typ); err != nil { + return err + } + + rep := desc.repeated() + + // Read the extension structure, and set it in + // the value we're constructing. 
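+			// A singular extension holds one value of the extension type;
+			// a repeated extension appends each occurrence to the slice
+			// already stored for it, creating the slice on first use.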
+ var ext reflect.Value + if !rep { + ext = reflect.New(typ).Elem() + } else { + ext = reflect.New(typ.Elem()).Elem() + } + if err := p.readAny(ext, props); err != nil { + return err + } + ep := sv.Addr().Interface().(extendableProto) + if !rep { + SetExtension(ep, desc, ext.Interface()) + } else { + old, err := GetExtension(ep, desc) + var sl reflect.Value + if err == nil { + sl = reflect.ValueOf(old) // existing slice + } else { + sl = reflect.MakeSlice(typ, 0, 1) + } + sl = reflect.Append(sl, ext) + SetExtension(ep, desc, sl.Interface()) + } + } else { + // This is a normal, non-extension field. + fi, props, ok := structFieldByName(st, tok.value) + if !ok { + return p.errorf("unknown field name %q in %v", tok.value, st) + } + + dst := sv.Field(fi) + isDstNil := isNil(dst) + + // Check that it's not already set if it's not a repeated field. + if !props.Repeated && !isDstNil { + return p.errorf("non-repeated field %q was repeated", tok.value) + } + + if err := p.checkForColon(props, st.Field(fi).Type); err != nil { + return err + } + + // Parse into the field. + if err := p.readAny(dst, props); err != nil { + return err + } + + if props.Required { + reqCount-- + } + } + + // For backward compatibility, permit a semicolon or comma after a field. + tok = p.next() + if tok.err != nil { + return tok.err + } + if tok.value != ";" && tok.value != "," { + p.back() + } + } + + if reqCount > 0 { + return p.missingRequiredFieldError(sv) + } + return nil +} + +func (p *textParser) readAny(v reflect.Value, props *Properties) *ParseError { + tok := p.next() + if tok.err != nil { + return tok.err + } + if tok.value == "" { + return p.errorf("unexpected EOF") + } + + switch fv := v; fv.Kind() { + case reflect.Slice: + at := v.Type() + if at.Elem().Kind() == reflect.Uint8 { + // Special case for []byte + if tok.value[0] != '"' && tok.value[0] != '\'' { + // Deliberately written out here, as the error after + // this switch statement would write "invalid []byte: ...", + // which is not as user-friendly. + return p.errorf("invalid string: %v", tok.value) + } + bytes := []byte(tok.unquoted) + fv.Set(reflect.ValueOf(bytes)) + return nil + } + // Repeated field. May already exist. + flen := fv.Len() + if flen == fv.Cap() { + nav := reflect.MakeSlice(at, flen, 2*flen+1) + reflect.Copy(nav, fv) + fv.Set(nav) + } + fv.SetLen(flen + 1) + + // Read one. + p.back() + return p.readAny(fv.Index(flen), props) + case reflect.Bool: + // Either "true", "false", 1 or 0. + switch tok.value { + case "true", "1": + fv.SetBool(true) + return nil + case "false", "0": + fv.SetBool(false) + return nil + } + case reflect.Float32, reflect.Float64: + v := tok.value + // Ignore 'f' for compatibility with output generated by C++, but don't + // remove 'f' when the value is "-inf" or "inf". 
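+		// For example, "17.0f" parses as 17, while "inf" keeps its final
+		// letter and is accepted by strconv.ParseFloat as +Inf.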
+ if strings.HasSuffix(v, "f") && tok.value != "-inf" && tok.value != "inf" { + v = v[:len(v)-1] + } + if f, err := strconv.ParseFloat(v, fv.Type().Bits()); err == nil { + fv.SetFloat(f) + return nil + } + case reflect.Int32: + if x, err := strconv.ParseInt(tok.value, 0, 32); err == nil { + fv.SetInt(x) + return nil + } + + if len(props.Enum) == 0 { + break + } + m, ok := enumValueMaps[props.Enum] + if !ok { + break + } + x, ok := m[tok.value] + if !ok { + break + } + fv.SetInt(int64(x)) + return nil + case reflect.Int64: + if x, err := strconv.ParseInt(tok.value, 0, 64); err == nil { + fv.SetInt(x) + return nil + } + + case reflect.Ptr: + // A basic field (indirected through pointer), or a repeated message/group + p.back() + fv.Set(reflect.New(fv.Type().Elem())) + return p.readAny(fv.Elem(), props) + case reflect.String: + if tok.value[0] == '"' || tok.value[0] == '\'' { + fv.SetString(tok.unquoted) + return nil + } + case reflect.Struct: + var terminator string + switch tok.value { + case "{": + terminator = "}" + case "<": + terminator = ">" + default: + return p.errorf("expected '{' or '<', found %q", tok.value) + } + // TODO: Handle nested messages which implement textUnmarshaler. + return p.readStruct(fv, terminator) + case reflect.Uint32: + if x, err := strconv.ParseUint(tok.value, 0, 32); err == nil { + fv.SetUint(uint64(x)) + return nil + } + case reflect.Uint64: + if x, err := strconv.ParseUint(tok.value, 0, 64); err == nil { + fv.SetUint(x) + return nil + } + } + return p.errorf("invalid %v: %v", v.Type(), tok.value) +} + +// UnmarshalText reads a protocol buffer in Text format. UnmarshalText resets pb +// before starting to unmarshal, so any existing data in pb is always removed. +func UnmarshalText(s string, pb Message) error { + if um, ok := pb.(textUnmarshaler); ok { + err := um.UnmarshalText([]byte(s)) + return err + } + pb.Reset() + v := reflect.ValueOf(pb) + if pe := newTextParser(s).readStruct(v.Elem(), ""); pe != nil { + return pe + } + return nil +} diff --git a/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/text_parser_test.go b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/text_parser_test.go new file mode 100644 index 0000000000..d043769ec1 --- /dev/null +++ b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/text_parser_test.go @@ -0,0 +1,462 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2010 The Go Authors. All rights reserved. +// http://code.google.com/p/goprotobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto_test + +import ( + "math" + "reflect" + "testing" + + . "./testdata" + . "code.google.com/p/goprotobuf/proto" +) + +type UnmarshalTextTest struct { + in string + err string // if "", no error expected + out *MyMessage +} + +func buildExtStructTest(text string) UnmarshalTextTest { + msg := &MyMessage{ + Count: Int32(42), + } + SetExtension(msg, E_Ext_More, &Ext{ + Data: String("Hello, world!"), + }) + return UnmarshalTextTest{in: text, out: msg} +} + +func buildExtDataTest(text string) UnmarshalTextTest { + msg := &MyMessage{ + Count: Int32(42), + } + SetExtension(msg, E_Ext_Text, String("Hello, world!")) + SetExtension(msg, E_Ext_Number, Int32(1729)) + return UnmarshalTextTest{in: text, out: msg} +} + +func buildExtRepStringTest(text string) UnmarshalTextTest { + msg := &MyMessage{ + Count: Int32(42), + } + if err := SetExtension(msg, E_Greeting, []string{"bula", "hola"}); err != nil { + panic(err) + } + return UnmarshalTextTest{in: text, out: msg} +} + +var unMarshalTextTests = []UnmarshalTextTest{ + // Basic + { + in: " count:42\n name:\"Dave\" ", + out: &MyMessage{ + Count: Int32(42), + Name: String("Dave"), + }, + }, + + // Empty quoted string + { + in: `count:42 name:""`, + out: &MyMessage{ + Count: Int32(42), + Name: String(""), + }, + }, + + // Quoted string concatenation + { + in: `count:42 name: "My name is "` + "\n" + `"elsewhere"`, + out: &MyMessage{ + Count: Int32(42), + Name: String("My name is elsewhere"), + }, + }, + + // Quoted string with escaped apostrophe + { + in: `count:42 name: "HOLIDAY - New Year\'s Day"`, + out: &MyMessage{ + Count: Int32(42), + Name: String("HOLIDAY - New Year's Day"), + }, + }, + + // Quoted string with single quote + { + in: `count:42 name: 'Roger "The Ramster" Ramjet'`, + out: &MyMessage{ + Count: Int32(42), + Name: String(`Roger "The Ramster" Ramjet`), + }, + }, + + // Quoted string with all the accepted special characters from the C++ test + { + in: `count:42 name: ` + "\"\\\"A string with \\' characters \\n and \\r newlines and \\t tabs and \\001 slashes \\\\ and multiple spaces\"", + out: &MyMessage{ + Count: Int32(42), + Name: String("\"A string with ' characters \n and \r newlines and \t tabs and \001 slashes \\ and multiple spaces"), + }, + }, + + // Quoted string with quoted backslash + { + in: `count:42 name: "\\'xyz"`, + out: &MyMessage{ + Count: Int32(42), + Name: String(`\'xyz`), + }, + }, + + // Quoted string with UTF-8 bytes. 
+ { + in: "count:42 name: '\303\277\302\201\xAB'", + out: &MyMessage{ + Count: Int32(42), + Name: String("\303\277\302\201\xAB"), + }, + }, + + // Bad quoted string + { + in: `inner: < host: "\0" >` + "\n", + err: `line 1.15: invalid quoted string "\0"`, + }, + + // Number too large for int64 + { + in: "count: 1 others { key: 123456789012345678901 }", + err: "line 1.23: invalid int64: 123456789012345678901", + }, + + // Number too large for int32 + { + in: "count: 1234567890123", + err: "line 1.7: invalid int32: 1234567890123", + }, + + // Number in hexadecimal + { + in: "count: 0x2beef", + out: &MyMessage{ + Count: Int32(0x2beef), + }, + }, + + // Number in octal + { + in: "count: 024601", + out: &MyMessage{ + Count: Int32(024601), + }, + }, + + // Floating point number with "f" suffix + { + in: "count: 4 others:< weight: 17.0f >", + out: &MyMessage{ + Count: Int32(4), + Others: []*OtherMessage{ + { + Weight: Float32(17), + }, + }, + }, + }, + + // Floating point positive infinity + { + in: "count: 4 bigfloat: inf", + out: &MyMessage{ + Count: Int32(4), + Bigfloat: Float64(math.Inf(1)), + }, + }, + + // Floating point negative infinity + { + in: "count: 4 bigfloat: -inf", + out: &MyMessage{ + Count: Int32(4), + Bigfloat: Float64(math.Inf(-1)), + }, + }, + + // Number too large for float32 + { + in: "others:< weight: 12345678901234567890123456789012345678901234567890 >", + err: "line 1.17: invalid float32: 12345678901234567890123456789012345678901234567890", + }, + + // Number posing as a quoted string + { + in: `inner: < host: 12 >` + "\n", + err: `line 1.15: invalid string: 12`, + }, + + // Quoted string posing as int32 + { + in: `count: "12"`, + err: `line 1.7: invalid int32: "12"`, + }, + + // Quoted string posing a float32 + { + in: `others:< weight: "17.4" >`, + err: `line 1.17: invalid float32: "17.4"`, + }, + + // Enum + { + in: `count:42 bikeshed: BLUE`, + out: &MyMessage{ + Count: Int32(42), + Bikeshed: MyMessage_BLUE.Enum(), + }, + }, + + // Repeated field + { + in: `count:42 pet: "horsey" pet:"bunny"`, + out: &MyMessage{ + Count: Int32(42), + Pet: []string{"horsey", "bunny"}, + }, + }, + + // Repeated message with/without colon and <>/{} + { + in: `count:42 others:{} others{} others:<> others:{}`, + out: &MyMessage{ + Count: Int32(42), + Others: []*OtherMessage{ + {}, + {}, + {}, + {}, + }, + }, + }, + + // Missing colon for inner message + { + in: `count:42 inner < host: "cauchy.syd" >`, + out: &MyMessage{ + Count: Int32(42), + Inner: &InnerMessage{ + Host: String("cauchy.syd"), + }, + }, + }, + + // Missing colon for string field + { + in: `name "Dave"`, + err: `line 1.5: expected ':', found "\"Dave\""`, + }, + + // Missing colon for int32 field + { + in: `count 42`, + err: `line 1.6: expected ':', found "42"`, + }, + + // Missing required field + { + in: ``, + err: `line 1.0: message testdata.MyMessage missing required field "count"`, + }, + + // Repeated non-repeated field + { + in: `name: "Rob" name: "Russ"`, + err: `line 1.12: non-repeated field "name" was repeated`, + }, + + // Group + { + in: `count: 17 SomeGroup { group_field: 12 }`, + out: &MyMessage{ + Count: Int32(17), + Somegroup: &MyMessage_SomeGroup{ + GroupField: Int32(12), + }, + }, + }, + + // Semicolon between fields + { + in: `count:3;name:"Calvin"`, + out: &MyMessage{ + Count: Int32(3), + Name: String("Calvin"), + }, + }, + // Comma between fields + { + in: `count:4,name:"Ezekiel"`, + out: &MyMessage{ + Count: Int32(4), + Name: String("Ezekiel"), + }, + }, + + // Extension + buildExtStructTest(`count: 42 
[testdata.Ext.more]:`), + buildExtStructTest(`count: 42 [testdata.Ext.more] {data:"Hello, world!"}`), + buildExtDataTest(`count: 42 [testdata.Ext.text]:"Hello, world!" [testdata.Ext.number]:1729`), + buildExtRepStringTest(`count: 42 [testdata.greeting]:"bula" [testdata.greeting]:"hola"`), + + // Big all-in-one + { + in: "count:42 # Meaning\n" + + `name:"Dave" ` + + `quote:"\"I didn't want to go.\"" ` + + `pet:"bunny" ` + + `pet:"kitty" ` + + `pet:"horsey" ` + + `inner:<` + + ` host:"footrest.syd" ` + + ` port:7001 ` + + ` connected:true ` + + `> ` + + `others:<` + + ` key:3735928559 ` + + ` value:"\x01A\a\f" ` + + `> ` + + `others:<` + + " weight:58.9 # Atomic weight of Co\n" + + ` inner:<` + + ` host:"lesha.mtv" ` + + ` port:8002 ` + + ` >` + + `>`, + out: &MyMessage{ + Count: Int32(42), + Name: String("Dave"), + Quote: String(`"I didn't want to go."`), + Pet: []string{"bunny", "kitty", "horsey"}, + Inner: &InnerMessage{ + Host: String("footrest.syd"), + Port: Int32(7001), + Connected: Bool(true), + }, + Others: []*OtherMessage{ + { + Key: Int64(3735928559), + Value: []byte{0x1, 'A', '\a', '\f'}, + }, + { + Weight: Float32(58.9), + Inner: &InnerMessage{ + Host: String("lesha.mtv"), + Port: Int32(8002), + }, + }, + }, + }, + }, +} + +func TestUnmarshalText(t *testing.T) { + for i, test := range unMarshalTextTests { + pb := new(MyMessage) + err := UnmarshalText(test.in, pb) + if test.err == "" { + // We don't expect failure. + if err != nil { + t.Errorf("Test %d: Unexpected error: %v", i, err) + } else if !reflect.DeepEqual(pb, test.out) { + t.Errorf("Test %d: Incorrect populated \nHave: %v\nWant: %v", + i, pb, test.out) + } + } else { + // We do expect failure. + if err == nil { + t.Errorf("Test %d: Didn't get expected error: %v", i, test.err) + } else if err.Error() != test.err { + t.Errorf("Test %d: Incorrect error.\nHave: %v\nWant: %v", + i, err.Error(), test.err) + } + } + } +} + +func TestUnmarshalTextCustomMessage(t *testing.T) { + msg := &textMessage{} + if err := UnmarshalText("custom", msg); err != nil { + t.Errorf("Unexpected error from custom unmarshal: %v", err) + } + if UnmarshalText("not custom", msg) == nil { + t.Errorf("Didn't get expected error from custom unmarshal") + } +} + +// Regression test; this caused a panic. +func TestRepeatedEnum(t *testing.T) { + pb := new(RepeatedEnum) + if err := UnmarshalText("color: RED", pb); err != nil { + t.Fatal(err) + } + exp := &RepeatedEnum{ + Color: []RepeatedEnum_Color{RepeatedEnum_RED}, + } + if !Equal(pb, exp) { + t.Errorf("Incorrect populated \nHave: %v\nWant: %v", pb, exp) + } +} + +var benchInput string + +func init() { + benchInput = "count: 4\n" + for i := 0; i < 1000; i++ { + benchInput += "pet: \"fido\"\n" + } + + // Check it is valid input. + pb := new(MyMessage) + err := UnmarshalText(benchInput, pb) + if err != nil { + panic("Bad benchmark input: " + err.Error()) + } +} + +func BenchmarkUnmarshalText(b *testing.B) { + pb := new(MyMessage) + for i := 0; i < b.N; i++ { + UnmarshalText(benchInput, pb) + } + b.SetBytes(int64(len(benchInput))) +} diff --git a/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/text_test.go b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/text_test.go new file mode 100644 index 0000000000..224a484e50 --- /dev/null +++ b/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto/text_test.go @@ -0,0 +1,408 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2010 The Go Authors. All rights reserved. 
+// http://code.google.com/p/goprotobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto_test + +import ( + "bytes" + "errors" + "io/ioutil" + "math" + "strings" + "testing" + + "code.google.com/p/goprotobuf/proto" + + pb "./testdata" +) + +// textMessage implements the methods that allow it to marshal and unmarshal +// itself as text. +type textMessage struct { +} + +func (*textMessage) MarshalText() ([]byte, error) { + return []byte("custom"), nil +} + +func (*textMessage) UnmarshalText(bytes []byte) error { + if string(bytes) != "custom" { + return errors.New("expected 'custom'") + } + return nil +} + +func (*textMessage) Reset() {} +func (*textMessage) String() string { return "" } +func (*textMessage) ProtoMessage() {} + +func newTestMessage() *pb.MyMessage { + msg := &pb.MyMessage{ + Count: proto.Int32(42), + Name: proto.String("Dave"), + Quote: proto.String(`"I didn't want to go."`), + Pet: []string{"bunny", "kitty", "horsey"}, + Inner: &pb.InnerMessage{ + Host: proto.String("footrest.syd"), + Port: proto.Int32(7001), + Connected: proto.Bool(true), + }, + Others: []*pb.OtherMessage{ + { + Key: proto.Int64(0xdeadbeef), + Value: []byte{1, 65, 7, 12}, + }, + { + Weight: proto.Float32(6.022), + Inner: &pb.InnerMessage{ + Host: proto.String("lesha.mtv"), + Port: proto.Int32(8002), + }, + }, + }, + Bikeshed: pb.MyMessage_BLUE.Enum(), + Somegroup: &pb.MyMessage_SomeGroup{ + GroupField: proto.Int32(8), + }, + // One normally wouldn't do this. + // This is an undeclared tag 13, as a varint (wire type 0) with value 4. + XXX_unrecognized: []byte{13<<3 | 0, 4}, + } + ext := &pb.Ext{ + Data: proto.String("Big gobs for big rats"), + } + if err := proto.SetExtension(msg, pb.E_Ext_More, ext); err != nil { + panic(err) + } + greetings := []string{"adg", "easy", "cow"} + if err := proto.SetExtension(msg, pb.E_Greeting, greetings); err != nil { + panic(err) + } + + // Add an unknown extension. We marshal a pb.Ext, and fake the ID. 
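+	// The EncodeVarint key below is field number 201 combined with wire
+	// type 2 (length-delimited): 201<<3 | proto.WireBytes.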
+ b, err := proto.Marshal(&pb.Ext{Data: proto.String("3G skiing")}) + if err != nil { + panic(err) + } + b = append(proto.EncodeVarint(201<<3|proto.WireBytes), b...) + proto.SetRawExtension(msg, 201, b) + + // Extensions can be plain fields, too, so let's test that. + b = append(proto.EncodeVarint(202<<3|proto.WireVarint), 19) + proto.SetRawExtension(msg, 202, b) + + return msg +} + +const text = `count: 42 +name: "Dave" +quote: "\"I didn't want to go.\"" +pet: "bunny" +pet: "kitty" +pet: "horsey" +inner: < + host: "footrest.syd" + port: 7001 + connected: true +> +others: < + key: 3735928559 + value: "\001A\007\014" +> +others: < + weight: 6.022 + inner: < + host: "lesha.mtv" + port: 8002 + > +> +bikeshed: BLUE +SomeGroup { + group_field: 8 +} +/* 2 unknown bytes */ +13: 4 +[testdata.Ext.more]: < + data: "Big gobs for big rats" +> +[testdata.greeting]: "adg" +[testdata.greeting]: "easy" +[testdata.greeting]: "cow" +/* 13 unknown bytes */ +201: "\t3G skiing" +/* 3 unknown bytes */ +202: 19 +` + +func TestMarshalText(t *testing.T) { + buf := new(bytes.Buffer) + if err := proto.MarshalText(buf, newTestMessage()); err != nil { + t.Fatalf("proto.MarshalText: %v", err) + } + s := buf.String() + if s != text { + t.Errorf("Got:\n===\n%v===\nExpected:\n===\n%v===\n", s, text) + } +} + +func TestMarshalTextCustomMessage(t *testing.T) { + buf := new(bytes.Buffer) + if err := proto.MarshalText(buf, &textMessage{}); err != nil { + t.Fatalf("proto.MarshalText: %v", err) + } + s := buf.String() + if s != "custom" { + t.Errorf("Got %q, expected %q", s, "custom") + } +} +func TestMarshalTextNil(t *testing.T) { + want := "" + tests := []proto.Message{nil, (*pb.MyMessage)(nil)} + for i, test := range tests { + buf := new(bytes.Buffer) + if err := proto.MarshalText(buf, test); err != nil { + t.Fatal(err) + } + if got := buf.String(); got != want { + t.Errorf("%d: got %q want %q", i, got, want) + } + } +} + +func TestMarshalTextUnknownEnum(t *testing.T) { + // The Color enum only specifies values 0-2. 
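+	// Value 3 has no registered name, so it must print as a bare number.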
+	m := &pb.MyMessage{Bikeshed: pb.MyMessage_Color(3).Enum()}
+	got := m.String()
+	const want = `bikeshed:3 `
+	if got != want {
+		t.Errorf("\n got %q\nwant %q", got, want)
+	}
+}
+
+func BenchmarkMarshalTextBuffered(b *testing.B) {
+	buf := new(bytes.Buffer)
+	m := newTestMessage()
+	for i := 0; i < b.N; i++ {
+		buf.Reset()
+		proto.MarshalText(buf, m)
+	}
+}
+
+func BenchmarkMarshalTextUnbuffered(b *testing.B) {
+	w := ioutil.Discard
+	m := newTestMessage()
+	for i := 0; i < b.N; i++ {
+		proto.MarshalText(w, m)
+	}
+}
+
+func compact(src string) string {
+	// s/[ \n]+/ /g; s/ $//;
+	dst := make([]byte, len(src))
+	space, comment := false, false
+	j := 0
+	for i := 0; i < len(src); i++ {
+		if strings.HasPrefix(src[i:], "/*") {
+			comment = true
+			i++
+			continue
+		}
+		if comment && strings.HasPrefix(src[i:], "*/") {
+			comment = false
+			i++
+			continue
+		}
+		if comment {
+			continue
+		}
+		c := src[i]
+		if c == ' ' || c == '\n' {
+			space = true
+			continue
+		}
+		if j > 0 && (dst[j-1] == ':' || dst[j-1] == '<' || dst[j-1] == '{') {
+			space = false
+		}
+		if c == '{' {
+			space = false
+		}
+		if space {
+			dst[j] = ' '
+			j++
+			space = false
+		}
+		dst[j] = c
+		j++
+	}
+	if space {
+		dst[j] = ' '
+		j++
+	}
+	return string(dst[0:j])
+}
+
+var compactText = compact(text)
+
+func TestCompactText(t *testing.T) {
+	s := proto.CompactTextString(newTestMessage())
+	if s != compactText {
+		t.Errorf("Got:\n===\n%v===\nExpected:\n===\n%v\n===\n", s, compactText)
+	}
+}
+
+func TestStringEscaping(t *testing.T) {
+	testCases := []struct {
+		in  *pb.Strings
+		out string
+	}{
+		{
+			// Test data from C++ test (TextFormatTest.StringEscape).
+			// Single divergence: we don't escape apostrophes.
+			&pb.Strings{StringField: proto.String("\"A string with ' characters \n and \r newlines and \t tabs and \001 slashes \\ and multiple spaces")},
+			"string_field: \"\\\"A string with ' characters \\n and \\r newlines and \\t tabs and \\001 slashes \\\\ and multiple spaces\"\n",
+		},
+		{
+			// Test data from the same C++ test.
+			&pb.Strings{StringField: proto.String("\350\260\267\346\255\214")},
+			"string_field: \"\\350\\260\\267\\346\\255\\214\"\n",
+		},
+		{
+			// Some bytes that are not valid UTF-8.
+			&pb.Strings{StringField: proto.String("\x00\x01\xff\x81")},
+			`string_field: "\000\001\377\201"` + "\n",
+		},
+	}
+
+	for i, tc := range testCases {
+		var buf bytes.Buffer
+		if err := proto.MarshalText(&buf, tc.in); err != nil {
+			t.Errorf("proto.MarshalText: %v", err)
+			continue
+		}
+		s := buf.String()
+		if s != tc.out {
+			t.Errorf("#%d: Got:\n%s\nExpected:\n%s\n", i, s, tc.out)
+			continue
+		}
+
+		// Check round-trip.
+		pb := new(pb.Strings)
+		if err := proto.UnmarshalText(s, pb); err != nil {
+			t.Errorf("#%d: UnmarshalText: %v", i, err)
+			continue
+		}
+		if !proto.Equal(pb, tc.in) {
+			t.Errorf("#%d: Round-trip failed:\nstart: %v\n end: %v", i, tc.in, pb)
+		}
+	}
+}
+
+// A limitedWriter accepts some output before it fails.
+// This is a proxy for something like a nearly-full or imminently-failing disk,
+// or a network connection that is about to die.
+type limitedWriter struct {
+	b     bytes.Buffer
+	limit int
+}
+
+var outOfSpace = errors.New("proto: insufficient space")
+
+func (w *limitedWriter) Write(p []byte) (n int, err error) {
+	var avail = w.limit - w.b.Len()
+	if avail <= 0 {
+		return 0, outOfSpace
+	}
+	if len(p) <= avail {
+		return w.b.Write(p)
+	}
+	n, _ = w.b.Write(p[:avail])
+	return n, outOfSpace
+}
+
+func TestMarshalTextFailing(t *testing.T) {
+	// Try lots of different sizes to exercise more error code-paths.
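+	// Each limit cuts the writer off at a different prefix of the expected
+	// output, so every write in the marshaler eventually sees a failure.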
+ for lim := 0; lim < len(text); lim++ { + buf := new(limitedWriter) + buf.limit = lim + err := proto.MarshalText(buf, newTestMessage()) + // We expect a certain error, but also some partial results in the buffer. + if err != outOfSpace { + t.Errorf("Got:\n===\n%v===\nExpected:\n===\n%v===\n", err, outOfSpace) + } + s := buf.b.String() + x := text[:buf.limit] + if s != x { + t.Errorf("Got:\n===\n%v===\nExpected:\n===\n%v===\n", s, x) + } + } +} + +func TestFloats(t *testing.T) { + tests := []struct { + f float64 + want string + }{ + {0, "0"}, + {4.7, "4.7"}, + {math.Inf(1), "inf"}, + {math.Inf(-1), "-inf"}, + {math.NaN(), "nan"}, + } + for _, test := range tests { + msg := &pb.FloatingPoint{F: &test.f} + got := strings.TrimSpace(msg.String()) + want := `f:` + test.want + if got != want { + t.Errorf("f=%f: got %q, want %q", test.f, got, want) + } + } +} + +func TestRepeatedNilText(t *testing.T) { + m := &pb.MessageList{ + Message: []*pb.MessageList_Message{ + nil, + &pb.MessageList_Message{ + Name: proto.String("Horse"), + }, + nil, + }, + } + want := `Message +Message { + name: "Horse" +} +Message +` + if s := proto.MarshalTextString(m); s != want { + t.Errorf(" got: %s\nwant: %s", s, want) + } +} diff --git a/Godeps/_workspace/src/github.com/golang/glog/LICENSE b/Godeps/_workspace/src/github.com/golang/glog/LICENSE new file mode 100644 index 0000000000..37ec93a14f --- /dev/null +++ b/Godeps/_workspace/src/github.com/golang/glog/LICENSE @@ -0,0 +1,191 @@ +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + +"License" shall mean the terms and conditions for use, reproduction, and +distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by the copyright +owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all other entities +that control, are controlled by, or are under common control with that entity. +For the purposes of this definition, "control" means (i) the power, direct or +indirect, to cause the direction or management of such entity, whether by +contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the +outstanding shares, or (iii) beneficial ownership of such entity. + +"You" (or "Your") shall mean an individual or Legal Entity exercising +permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, including +but not limited to software source code, documentation source, and configuration +files. + +"Object" form shall mean any form resulting from mechanical transformation or +translation of a Source form, including but not limited to compiled object code, +generated documentation, and conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or Object form, made +available under the License, as indicated by a copyright notice that is included +in or attached to the work (an example is provided in the Appendix below). + +"Derivative Works" shall mean any work, whether in Source or Object form, that +is based on (or derived from) the Work and for which the editorial revisions, +annotations, elaborations, or other modifications represent, as a whole, an +original work of authorship. 
For the purposes of this License, Derivative Works +shall not include works that remain separable from, or merely link (or bind by +name) to the interfaces of, the Work and Derivative Works thereof. + +"Contribution" shall mean any work of authorship, including the original version +of the Work and any modifications or additions to that Work or Derivative Works +thereof, that is intentionally submitted to Licensor for inclusion in the Work +by the copyright owner or by an individual or Legal Entity authorized to submit +on behalf of the copyright owner. For the purposes of this definition, +"submitted" means any form of electronic, verbal, or written communication sent +to the Licensor or its representatives, including but not limited to +communication on electronic mailing lists, source code control systems, and +issue tracking systems that are managed by, or on behalf of, the Licensor for +the purpose of discussing and improving the Work, but excluding communication +that is conspicuously marked or otherwise designated in writing by the copyright +owner as "Not a Contribution." + +"Contributor" shall mean Licensor and any individual or Legal Entity on behalf +of whom a Contribution has been received by Licensor and subsequently +incorporated within the Work. + +2. Grant of Copyright License. + +Subject to the terms and conditions of this License, each Contributor hereby +grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, +irrevocable copyright license to reproduce, prepare Derivative Works of, +publicly display, publicly perform, sublicense, and distribute the Work and such +Derivative Works in Source or Object form. + +3. Grant of Patent License. + +Subject to the terms and conditions of this License, each Contributor hereby +grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, +irrevocable (except as stated in this section) patent license to make, have +made, use, offer to sell, sell, import, and otherwise transfer the Work, where +such license applies only to those patent claims licensable by such Contributor +that are necessarily infringed by their Contribution(s) alone or by combination +of their Contribution(s) with the Work to which such Contribution(s) was +submitted. If You institute patent litigation against any entity (including a +cross-claim or counterclaim in a lawsuit) alleging that the Work or a +Contribution incorporated within the Work constitutes direct or contributory +patent infringement, then any patent licenses granted to You under this License +for that Work shall terminate as of the date such litigation is filed. + +4. Redistribution. 
+ +You may reproduce and distribute copies of the Work or Derivative Works thereof +in any medium, with or without modifications, and in Source or Object form, +provided that You meet the following conditions: + +You must give any other recipients of the Work or Derivative Works a copy of +this License; and +You must cause any modified files to carry prominent notices stating that You +changed the files; and +You must retain, in the Source form of any Derivative Works that You distribute, +all copyright, patent, trademark, and attribution notices from the Source form +of the Work, excluding those notices that do not pertain to any part of the +Derivative Works; and +If the Work includes a "NOTICE" text file as part of its distribution, then any +Derivative Works that You distribute must include a readable copy of the +attribution notices contained within such NOTICE file, excluding those notices +that do not pertain to any part of the Derivative Works, in at least one of the +following places: within a NOTICE text file distributed as part of the +Derivative Works; within the Source form or documentation, if provided along +with the Derivative Works; or, within a display generated by the Derivative +Works, if and wherever such third-party notices normally appear. The contents of +the NOTICE file are for informational purposes only and do not modify the +License. You may add Your own attribution notices within Derivative Works that +You distribute, alongside or as an addendum to the NOTICE text from the Work, +provided that such additional attribution notices cannot be construed as +modifying the License. +You may add Your own copyright statement to Your modifications and may provide +additional or different license terms and conditions for use, reproduction, or +distribution of Your modifications, or for any such Derivative Works as a whole, +provided Your use, reproduction, and distribution of the Work otherwise complies +with the conditions stated in this License. + +5. Submission of Contributions. + +Unless You explicitly state otherwise, any Contribution intentionally submitted +for inclusion in the Work by You to the Licensor shall be under the terms and +conditions of this License, without any additional terms or conditions. +Notwithstanding the above, nothing herein shall supersede or modify the terms of +any separate license agreement you may have executed with Licensor regarding +such Contributions. + +6. Trademarks. + +This License does not grant permission to use the trade names, trademarks, +service marks, or product names of the Licensor, except as required for +reasonable and customary use in describing the origin of the Work and +reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. + +Unless required by applicable law or agreed to in writing, Licensor provides the +Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, +including, without limitation, any warranties or conditions of TITLE, +NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are +solely responsible for determining the appropriateness of using or +redistributing the Work and assume any risks associated with Your exercise of +permissions under this License. + +8. Limitation of Liability. 
+ +In no event and under no legal theory, whether in tort (including negligence), +contract, or otherwise, unless required by applicable law (such as deliberate +and grossly negligent acts) or agreed to in writing, shall any Contributor be +liable to You for damages, including any direct, indirect, special, incidental, +or consequential damages of any character arising as a result of this License or +out of the use or inability to use the Work (including but not limited to +damages for loss of goodwill, work stoppage, computer failure or malfunction, or +any and all other commercial damages or losses), even if such Contributor has +been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. + +While redistributing the Work or Derivative Works thereof, You may choose to +offer, and charge a fee for, acceptance of support, warranty, indemnity, or +other liability obligations and/or rights consistent with this License. However, +in accepting such obligations, You may act only on Your own behalf and on Your +sole responsibility, not on behalf of any other Contributor, and only if You +agree to indemnify, defend, and hold each Contributor harmless for any liability +incurred by, or claims asserted against, such Contributor by reason of your +accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work + +To apply the Apache License to your work, attach the following boilerplate +notice, with the fields enclosed by brackets "[]" replaced with your own +identifying information. (Don't include the brackets!) The text should be +enclosed in the appropriate comment syntax for the file format. We also +recommend that a file or class name and description of purpose be included on +the same "printed page" as the copyright notice for easier identification within +third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/Godeps/_workspace/src/github.com/golang/glog/README b/Godeps/_workspace/src/github.com/golang/glog/README new file mode 100644 index 0000000000..5f9c11485e --- /dev/null +++ b/Godeps/_workspace/src/github.com/golang/glog/README @@ -0,0 +1,44 @@ +glog +==== + +Leveled execution logs for Go. + +This is an efficient pure Go implementation of leveled logs in the +manner of the open source C++ package + http://code.google.com/p/google-glog + +By binding methods to booleans it is possible to use the log package +without paying the expense of evaluating the arguments to the log. +Through the -vmodule flag, the package also provides fine-grained +control over logging at the file level. + +The comment from glog.go introduces the ideas: + + Package glog implements logging analogous to the Google-internal + C++ INFO/ERROR/V setup. It provides functions Info, Warning, + Error, Fatal, plus formatting variants such as Infof. It + also provides V-style logging controlled by the -v and + -vmodule=file=2 flags. 
+ + Basic examples: + + glog.Info("Prepare to repel boarders") + + glog.Fatalf("Initialization failed: %s", err) + + See the documentation for the V function for an explanation + of these examples: + + if glog.V(2) { + glog.Info("Starting transaction...") + } + + glog.V(2).Infoln("Processed", nItems, "elements") + + +The repository contains an open source version of the log package +used inside Google. The master copy of the source lives inside +Google, not here. The code in this repo is for export only and is not itself +under development. Feature requests will be ignored. + +Send bug reports to golang-nuts@googlegroups.com. diff --git a/Godeps/_workspace/src/github.com/golang/glog/glog.go b/Godeps/_workspace/src/github.com/golang/glog/glog.go new file mode 100644 index 0000000000..3e63fffd5e --- /dev/null +++ b/Godeps/_workspace/src/github.com/golang/glog/glog.go @@ -0,0 +1,1177 @@ +// Go support for leveled logs, analogous to https://code.google.com/p/google-glog/ +// +// Copyright 2013 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package glog implements logging analogous to the Google-internal C++ INFO/ERROR/V setup. +// It provides functions Info, Warning, Error, Fatal, plus formatting variants such as +// Infof. It also provides V-style logging controlled by the -v and -vmodule=file=2 flags. +// +// Basic examples: +// +// glog.Info("Prepare to repel boarders") +// +// glog.Fatalf("Initialization failed: %s", err) +// +// See the documentation for the V function for an explanation of these examples: +// +// if glog.V(2) { +// glog.Info("Starting transaction...") +// } +// +// glog.V(2).Infoln("Processed", nItems, "elements") +// +// Log output is buffered and written periodically using Flush. Programs +// should call Flush before exiting to guarantee all log output is written. +// +// By default, all log statements write to files in a temporary directory. +// This package provides several flags that modify this behavior. +// As a result, flag.Parse must be called before any logging is done. +// +// -logtostderr=false +// Logs are written to standard error instead of to files. +// -alsologtostderr=false +// Logs are written to standard error as well as to files. +// -stderrthreshold=ERROR +// Log events at or above this severity are logged to standard +// error as well as to files. +// -log_dir="" +// Log files will be written to this directory instead of the +// default temporary directory. +// +// Other flags provide aids to debugging. +// +// -log_backtrace_at="" +// When set to a file and line number holding a logging statement, +// such as +// -log_backtrace_at=gopherflakes.go:234 +// a stack trace will be written to the Info log whenever execution +// hits that statement. (Unlike with -vmodule, the ".go" must be +// present.) +// -v=0 +// Enable V-leveled logging at the specified level. 
+// -vmodule="" +// The syntax of the argument is a comma-separated list of pattern=N, +// where pattern is a literal file name (minus the ".go" suffix) or +// "glob" pattern and N is a V level. For instance, +// -vmodule=gopher*=3 +// sets the V level to 3 in all Go files whose names begin "gopher". +// +package glog + +import ( + "bufio" + "bytes" + "errors" + "flag" + "fmt" + "io" + stdLog "log" + "os" + "path/filepath" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" +) + +// severity identifies the sort of log: info, warning etc. It also implements +// the flag.Value interface. The -stderrthreshold flag is of type severity and +// should be modified only through the flag.Value interface. The values match +// the corresponding constants in C++. +type severity int32 // sync/atomic int32 + +// These constants identify the log levels in order of increasing severity. +// A message written to a high-severity log file is also written to each +// lower-severity log file. +const ( + infoLog severity = iota + warningLog + errorLog + fatalLog + numSeverity = 4 +) + +const severityChar = "IWEF" + +var severityName = []string{ + infoLog: "INFO", + warningLog: "WARNING", + errorLog: "ERROR", + fatalLog: "FATAL", +} + +// get returns the value of the severity. +func (s *severity) get() severity { + return severity(atomic.LoadInt32((*int32)(s))) +} + +// set sets the value of the severity. +func (s *severity) set(val severity) { + atomic.StoreInt32((*int32)(s), int32(val)) +} + +// String is part of the flag.Value interface. +func (s *severity) String() string { + return strconv.FormatInt(int64(*s), 10) +} + +// Get is part of the flag.Value interface. +func (s *severity) Get() interface{} { + return *s +} + +// Set is part of the flag.Value interface. +func (s *severity) Set(value string) error { + var threshold severity + // Is it a known name? + if v, ok := severityByName(value); ok { + threshold = v + } else { + v, err := strconv.Atoi(value) + if err != nil { + return err + } + threshold = severity(v) + } + logging.stderrThreshold.set(threshold) + return nil +} + +func severityByName(s string) (severity, bool) { + s = strings.ToUpper(s) + for i, name := range severityName { + if name == s { + return severity(i), true + } + } + return 0, false +} + +// OutputStats tracks the number of output lines and bytes written. +type OutputStats struct { + lines int64 + bytes int64 +} + +// Lines returns the number of lines written. +func (s *OutputStats) Lines() int64 { + return atomic.LoadInt64(&s.lines) +} + +// Bytes returns the number of bytes written. +func (s *OutputStats) Bytes() int64 { + return atomic.LoadInt64(&s.bytes) +} + +// Stats tracks the number of lines of output and number of bytes +// per severity level. Values must be read with atomic.LoadInt64. +var Stats struct { + Info, Warning, Error OutputStats +} + +var severityStats = [numSeverity]*OutputStats{ + infoLog: &Stats.Info, + warningLog: &Stats.Warning, + errorLog: &Stats.Error, +} + +// Level is exported because it appears in the arguments to V and is +// the type of the v flag, which can be set programmatically. +// It's a distinct type because we want to discriminate it from logType. +// Variables of type level are only changed under logging.mu. +// The -v flag is read only with atomic ops, so the state of the logging +// module is consistent. + +// Level is treated as a sync/atomic int32. + +// Level specifies a level of verbosity for V logs. 
*Level implements +// flag.Value; the -v flag is of type Level and should be modified +// only through the flag.Value interface. +type Level int32 + +// get returns the value of the Level. +func (l *Level) get() Level { + return Level(atomic.LoadInt32((*int32)(l))) +} + +// set sets the value of the Level. +func (l *Level) set(val Level) { + atomic.StoreInt32((*int32)(l), int32(val)) +} + +// String is part of the flag.Value interface. +func (l *Level) String() string { + return strconv.FormatInt(int64(*l), 10) +} + +// Get is part of the flag.Value interface. +func (l *Level) Get() interface{} { + return *l +} + +// Set is part of the flag.Value interface. +func (l *Level) Set(value string) error { + v, err := strconv.Atoi(value) + if err != nil { + return err + } + logging.mu.Lock() + defer logging.mu.Unlock() + logging.setVState(Level(v), logging.vmodule.filter, false) + return nil +} + +// moduleSpec represents the setting of the -vmodule flag. +type moduleSpec struct { + filter []modulePat +} + +// modulePat contains a filter for the -vmodule flag. +// It holds a verbosity level and a file pattern to match. +type modulePat struct { + pattern string + literal bool // The pattern is a literal string + level Level +} + +// match reports whether the file matches the pattern. It uses a string +// comparison if the pattern contains no metacharacters. +func (m *modulePat) match(file string) bool { + if m.literal { + return file == m.pattern + } + match, _ := filepath.Match(m.pattern, file) + return match +} + +func (m *moduleSpec) String() string { + // Lock because the type is not atomic. TODO: clean this up. + logging.mu.Lock() + defer logging.mu.Unlock() + var b bytes.Buffer + for i, f := range m.filter { + if i > 0 { + b.WriteRune(',') + } + fmt.Fprintf(&b, "%s=%d", f.pattern, f.level) + } + return b.String() +} + +// Get is part of the (Go 1.2) flag.Getter interface. It always returns nil for this flag type since the +// struct is not exported. +func (m *moduleSpec) Get() interface{} { + return nil +} + +var errVmoduleSyntax = errors.New("syntax error: expect comma-separated list of filename=N") + +// Syntax: -vmodule=recordio=2,file=1,gfs*=3 +func (m *moduleSpec) Set(value string) error { + var filter []modulePat + for _, pat := range strings.Split(value, ",") { + if len(pat) == 0 { + // Empty strings such as from a trailing comma can be ignored. + continue + } + patLev := strings.Split(pat, "=") + if len(patLev) != 2 || len(patLev[0]) == 0 || len(patLev[1]) == 0 { + return errVmoduleSyntax + } + pattern := patLev[0] + v, err := strconv.Atoi(patLev[1]) + if err != nil { + return errors.New("syntax error: expect comma-separated list of filename=N") + } + if v < 0 { + return errors.New("negative value for vmodule level") + } + if v == 0 { + continue // Ignore. It's harmless but no point in paying the overhead. + } + // TODO: check syntax of filter? + filter = append(filter, modulePat{pattern, isLiteral(pattern), Level(v)}) + } + logging.mu.Lock() + defer logging.mu.Unlock() + logging.setVState(logging.verbosity, filter, true) + return nil +} + +// isLiteral reports whether the pattern is a literal string, that is, has no metacharacters +// that require filepath.Match to be called to match the pattern. +func isLiteral(pattern string) bool { + return !strings.ContainsAny(pattern, `\*?[]`) +} + +// traceLocation represents the setting of the -log_backtrace_at flag. +type traceLocation struct { + file string + line int +} + +// isSet reports whether the trace location has been specified. 
+// logging.mu is held.
+func (t *traceLocation) isSet() bool {
+	return t.line > 0
+}
+
+// match reports whether the specified file and line matches the trace location.
+// The argument file name is the full path, not the basename specified in the flag.
+// logging.mu is held.
+func (t *traceLocation) match(file string, line int) bool {
+	if t.line != line {
+		return false
+	}
+	if i := strings.LastIndex(file, "/"); i >= 0 {
+		file = file[i+1:]
+	}
+	return t.file == file
+}
+
+func (t *traceLocation) String() string {
+	// Lock because the type is not atomic. TODO: clean this up.
+	logging.mu.Lock()
+	defer logging.mu.Unlock()
+	return fmt.Sprintf("%s:%d", t.file, t.line)
+}
+
+// Get is part of the (Go 1.2) flag.Getter interface. It always returns nil for this flag type since the
+// struct is not exported.
+func (t *traceLocation) Get() interface{} {
+	return nil
+}
+
+var errTraceSyntax = errors.New("syntax error: expect file.go:234")
+
+// Syntax: -log_backtrace_at=gopherflakes.go:234
+// Note that unlike vmodule the file extension is included here.
+func (t *traceLocation) Set(value string) error {
+	if value == "" {
+		// Unset.
+		t.line = 0
+		t.file = ""
+		return nil
+	}
+	fields := strings.Split(value, ":")
+	if len(fields) != 2 {
+		return errTraceSyntax
+	}
+	file, line := fields[0], fields[1]
+	if !strings.Contains(file, ".") {
+		return errTraceSyntax
+	}
+	v, err := strconv.Atoi(line)
+	if err != nil {
+		return errTraceSyntax
+	}
+	if v <= 0 {
+		return errors.New("negative or zero value for level")
+	}
+	logging.mu.Lock()
+	defer logging.mu.Unlock()
+	t.line = v
+	t.file = file
+	return nil
+}
+
+// flushSyncWriter is the interface satisfied by logging destinations.
+type flushSyncWriter interface {
+	Flush() error
+	Sync() error
+	io.Writer
+}
+
+func init() {
+	flag.BoolVar(&logging.toStderr, "logtostderr", false, "log to standard error instead of files")
+	flag.BoolVar(&logging.alsoToStderr, "alsologtostderr", false, "log to standard error as well as files")
+	flag.Var(&logging.verbosity, "v", "log level for V logs")
+	flag.Var(&logging.stderrThreshold, "stderrthreshold", "logs at or above this threshold go to stderr")
+	flag.Var(&logging.vmodule, "vmodule", "comma-separated list of pattern=N settings for file-filtered logging")
+	flag.Var(&logging.traceLocation, "log_backtrace_at", "when logging hits line file:N, emit a stack trace")
+
+	// Default stderrThreshold is ERROR.
+	logging.stderrThreshold = errorLog
+
+	logging.setVState(0, nil, false)
+	go logging.flushDaemon()
+}
+
+// Flush flushes all pending log I/O.
+func Flush() {
+	logging.lockAndFlushAll()
+}
+
+// loggingT collects all the global state of the logging setup.
+type loggingT struct {
+	// Boolean flags. Not handled atomically because the flag.Value interface
+	// does not let us avoid the =true, and that shorthand is necessary for
+	// compatibility. TODO: does this matter enough to fix? Seems unlikely.
+	toStderr     bool // The -logtostderr flag.
+	alsoToStderr bool // The -alsologtostderr flag.
+
+	// Level flag. Handled atomically.
+	stderrThreshold severity // The -stderrthreshold flag.
+
+	// freeList is a list of byte buffers, maintained under freeListMu.
+	freeList *buffer
+	// freeListMu maintains the free list. It is separate from the main mutex
+	// so buffers can be grabbed and printed to without holding the main lock,
+	// for better parallelization.
+	freeListMu sync.Mutex
+
+	// mu protects the remaining elements of this structure and is
+	// used to synchronize logging.
+	mu sync.Mutex
+	// file holds the writer for each of the log types.
+	file [numSeverity]flushSyncWriter
+	// pcs is used in V to avoid an allocation when computing the caller's PC.
+	pcs [1]uintptr
+	// vmap is a cache of the V Level for each V() call site, identified by PC.
+	// It is wiped whenever the vmodule flag changes state.
+	vmap map[uintptr]Level
+	// filterLength stores the length of the vmodule filter chain. If greater
+	// than zero, it means vmodule is enabled. It may be read safely
+	// using atomic.LoadInt32, but is only modified under mu.
+	filterLength int32
+	// traceLocation is the state of the -log_backtrace_at flag.
+	traceLocation traceLocation
+	// These flags are modified only under lock, although verbosity may be fetched
+	// safely using atomic.LoadInt32.
+	vmodule   moduleSpec // The state of the -vmodule flag.
+	verbosity Level      // V logging level, the value of the -v flag.
+}
+
+// buffer holds a byte Buffer for reuse. The zero value is ready for use.
+type buffer struct {
+	bytes.Buffer
+	tmp  [64]byte // temporary byte array for creating headers.
+	next *buffer
+}
+
+var logging loggingT
+
+// setVState sets a consistent state for V logging.
+// l.mu is held.
+func (l *loggingT) setVState(verbosity Level, filter []modulePat, setFilter bool) {
+	// Turn verbosity off so V will not fire while we are in transition.
+	logging.verbosity.set(0)
+	// Ditto for filter length.
+	atomic.StoreInt32(&logging.filterLength, 0)
+
+	// Set the new filters and wipe the pc->Level map if the filter has changed.
+	if setFilter {
+		logging.vmodule.filter = filter
+		logging.vmap = make(map[uintptr]Level)
+	}
+
+	// Things are consistent now, so enable filtering and verbosity.
+	// They are enabled in order opposite to that in V.
+	atomic.StoreInt32(&logging.filterLength, int32(len(filter)))
+	logging.verbosity.set(verbosity)
+}
+
+// getBuffer returns a new, ready-to-use buffer.
+func (l *loggingT) getBuffer() *buffer {
+	l.freeListMu.Lock()
+	b := l.freeList
+	if b != nil {
+		l.freeList = b.next
+	}
+	l.freeListMu.Unlock()
+	if b == nil {
+		b = new(buffer)
+	} else {
+		b.next = nil
+		b.Reset()
+	}
+	return b
+}
+
+// putBuffer returns a buffer to the free list.
+func (l *loggingT) putBuffer(b *buffer) {
+	if b.Len() >= 256 {
+		// Let big buffers die a natural death.
+		return
+	}
+	l.freeListMu.Lock()
+	b.next = l.freeList
+	l.freeList = b
+	l.freeListMu.Unlock()
+}
+
+var timeNow = time.Now // Stubbed out for testing.
+
+/*
+header formats a log header as defined by the C++ implementation.
+It returns a buffer containing the formatted header and the user's file and line number.
+The depth specifies how many stack frames above lives the source line to be identified in the log message.
+
+Log lines have this form:
+	Lmmdd hh:mm:ss.uuuuuu threadid file:line] msg...
+where the fields are defined as follows:
+	L                A single character, representing the log level (eg 'I' for INFO)
+	mm               The month (zero padded; ie May is '05')
+	dd               The day (zero padded)
+	hh:mm:ss.uuuuuu  Time in hours, minutes and fractional seconds
+	threadid         The space-padded thread ID as returned by GetTID()
+	file             The file name
+	line             The line number
+	msg              The user-supplied message
+*/
+func (l *loggingT) header(s severity, depth int) (*buffer, string, int) {
+	_, file, line, ok := runtime.Caller(3 + depth)
+	if !ok {
+		file = "???"
+ line = 1 + } else { + slash := strings.LastIndex(file, "/") + if slash >= 0 { + file = file[slash+1:] + } + } + return l.formatHeader(s, file, line), file, line +} + +// formatHeader formats a log header using the provided file name and line number. +func (l *loggingT) formatHeader(s severity, file string, line int) *buffer { + now := timeNow() + if line < 0 { + line = 0 // not a real line number, but acceptable to someDigits + } + if s > fatalLog { + s = infoLog // for safety. + } + buf := l.getBuffer() + + // Avoid Fprintf, for speed. The format is so simple that we can do it quickly by hand. + // It's worth about 3X. Fprintf is hard. + _, month, day := now.Date() + hour, minute, second := now.Clock() + // Lmmdd hh:mm:ss.uuuuuu threadid file:line] + buf.tmp[0] = severityChar[s] + buf.twoDigits(1, int(month)) + buf.twoDigits(3, day) + buf.tmp[5] = ' ' + buf.twoDigits(6, hour) + buf.tmp[8] = ':' + buf.twoDigits(9, minute) + buf.tmp[11] = ':' + buf.twoDigits(12, second) + buf.tmp[14] = '.' + buf.nDigits(6, 15, now.Nanosecond()/1000, '0') + buf.tmp[21] = ' ' + buf.nDigits(7, 22, pid, ' ') // TODO: should be TID + buf.tmp[29] = ' ' + buf.Write(buf.tmp[:30]) + buf.WriteString(file) + buf.tmp[0] = ':' + n := buf.someDigits(1, line) + buf.tmp[n+1] = ']' + buf.tmp[n+2] = ' ' + buf.Write(buf.tmp[:n+3]) + return buf +} + +// Some custom tiny helper functions to print the log header efficiently. + +const digits = "0123456789" + +// twoDigits formats a zero-prefixed two-digit integer at buf.tmp[i]. +func (buf *buffer) twoDigits(i, d int) { + buf.tmp[i+1] = digits[d%10] + d /= 10 + buf.tmp[i] = digits[d%10] +} + +// nDigits formats an n-digit integer at buf.tmp[i], +// padding with pad on the left. +// It assumes d >= 0. +func (buf *buffer) nDigits(n, i, d int, pad byte) { + j := n - 1 + for ; j >= 0 && d > 0; j-- { + buf.tmp[i+j] = digits[d%10] + d /= 10 + } + for ; j >= 0; j-- { + buf.tmp[i+j] = pad + } +} + +// someDigits formats a zero-prefixed variable-width integer at buf.tmp[i]. +func (buf *buffer) someDigits(i, d int) int { + // Print into the top, then copy down. We know there's space for at least + // a 10-digit number. + j := len(buf.tmp) + for { + j-- + buf.tmp[j] = digits[d%10] + d /= 10 + if d == 0 { + break + } + } + return copy(buf.tmp[i:], buf.tmp[j:]) +} + +func (l *loggingT) println(s severity, args ...interface{}) { + buf, file, line := l.header(s, 0) + fmt.Fprintln(buf, args...) + l.output(s, buf, file, line, false) +} + +func (l *loggingT) print(s severity, args ...interface{}) { + l.printDepth(s, 1, args...) +} + +func (l *loggingT) printDepth(s severity, depth int, args ...interface{}) { + buf, file, line := l.header(s, depth) + fmt.Fprint(buf, args...) + if buf.Bytes()[buf.Len()-1] != '\n' { + buf.WriteByte('\n') + } + l.output(s, buf, file, line, false) +} + +func (l *loggingT) printf(s severity, format string, args ...interface{}) { + buf, file, line := l.header(s, 0) + fmt.Fprintf(buf, format, args...) + if buf.Bytes()[buf.Len()-1] != '\n' { + buf.WriteByte('\n') + } + l.output(s, buf, file, line, false) +} + +// printWithFileLine behaves like print but uses the provided file and line number. If +// alsoLogToStderr is true, the log message always appears on standard error; it +// will also appear in the log file unless --logtostderr is set. +func (l *loggingT) printWithFileLine(s severity, file string, line int, alsoToStderr bool, args ...interface{}) { + buf := l.formatHeader(s, file, line) + fmt.Fprint(buf, args...) 
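+	// Ensure the entry ends in a newline before it is written out.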
+ if buf.Bytes()[buf.Len()-1] != '\n' { + buf.WriteByte('\n') + } + l.output(s, buf, file, line, alsoToStderr) +} + +// output writes the data to the log files and releases the buffer. +func (l *loggingT) output(s severity, buf *buffer, file string, line int, alsoToStderr bool) { + l.mu.Lock() + if l.traceLocation.isSet() { + if l.traceLocation.match(file, line) { + buf.Write(stacks(false)) + } + } + data := buf.Bytes() + if l.toStderr { + os.Stderr.Write(data) + } else { + if alsoToStderr || l.alsoToStderr || s >= l.stderrThreshold.get() { + os.Stderr.Write(data) + } + if l.file[s] == nil { + if err := l.createFiles(s); err != nil { + os.Stderr.Write(data) // Make sure the message appears somewhere. + l.exit(err) + } + } + switch s { + case fatalLog: + l.file[fatalLog].Write(data) + fallthrough + case errorLog: + l.file[errorLog].Write(data) + fallthrough + case warningLog: + l.file[warningLog].Write(data) + fallthrough + case infoLog: + l.file[infoLog].Write(data) + } + } + if s == fatalLog { + // If we got here via Exit rather than Fatal, print no stacks. + if atomic.LoadUint32(&fatalNoStacks) > 0 { + l.mu.Unlock() + timeoutFlush(10 * time.Second) + os.Exit(1) + } + // Dump all goroutine stacks before exiting. + // First, make sure we see the trace for the current goroutine on standard error. + // If -logtostderr has been specified, the loop below will do that anyway + // as the first stack in the full dump. + if !l.toStderr { + os.Stderr.Write(stacks(false)) + } + // Write the stack trace for all goroutines to the files. + trace := stacks(true) + logExitFunc = func(error) {} // If we get a write error, we'll still exit below. + for log := fatalLog; log >= infoLog; log-- { + if f := l.file[log]; f != nil { // Can be nil if -logtostderr is set. + f.Write(trace) + } + } + l.mu.Unlock() + timeoutFlush(10 * time.Second) + os.Exit(255) // C++ uses -1, which is silly because it's anded with 255 anyway. + } + l.putBuffer(buf) + l.mu.Unlock() + if stats := severityStats[s]; stats != nil { + atomic.AddInt64(&stats.lines, 1) + atomic.AddInt64(&stats.bytes, int64(len(data))) + } +} + +// timeoutFlush calls Flush and returns when it completes or after timeout +// elapses, whichever happens first. This is needed because the hooks invoked +// by Flush may deadlock when glog.Fatal is called from a hook that holds +// a lock. +func timeoutFlush(timeout time.Duration) { + done := make(chan bool, 1) + go func() { + Flush() // calls logging.lockAndFlushAll() + done <- true + }() + select { + case <-done: + case <-time.After(timeout): + fmt.Fprintln(os.Stderr, "glog: Flush took longer than", timeout) + } +} + +// stacks is a wrapper for runtime.Stack that attempts to recover the data for all goroutines. +func stacks(all bool) []byte { + // We don't know how big the traces are, so grow a few times if they don't fit. Start large, though. + n := 10000 + if all { + n = 100000 + } + var trace []byte + for i := 0; i < 5; i++ { + trace = make([]byte, n) + nbytes := runtime.Stack(trace, all) + if nbytes < len(trace) { + return trace[:nbytes] + } + n *= 2 + } + return trace +} + +// logExitFunc provides a simple mechanism to override the default behavior +// of exiting on error. Used in testing and to guarantee we reach a required exit +// for fatal logs. Instead, exit could be a function rather than a method but that +// would make its use clumsier. +var logExitFunc func(error) + +// exit is called if there is trouble creating or writing log files. 
+// It flushes the logs and exits the program; there's no point in hanging around. +// l.mu is held. +func (l *loggingT) exit(err error) { + fmt.Fprintf(os.Stderr, "log: exiting because of error: %s\n", err) + // If logExitFunc is set, we do that instead of exiting. + if logExitFunc != nil { + logExitFunc(err) + return + } + l.flushAll() + os.Exit(2) +} + +// syncBuffer joins a bufio.Writer to its underlying file, providing access to the +// file's Sync method and providing a wrapper for the Write method that provides log +// file rotation. There are conflicting methods, so the file cannot be embedded. +// l.mu is held for all its methods. +type syncBuffer struct { + logger *loggingT + *bufio.Writer + file *os.File + sev severity + nbytes uint64 // The number of bytes written to this file +} + +func (sb *syncBuffer) Sync() error { + return sb.file.Sync() +} + +func (sb *syncBuffer) Write(p []byte) (n int, err error) { + if sb.nbytes+uint64(len(p)) >= MaxSize { + if err := sb.rotateFile(time.Now()); err != nil { + sb.logger.exit(err) + } + } + n, err = sb.Writer.Write(p) + sb.nbytes += uint64(n) + if err != nil { + sb.logger.exit(err) + } + return +} + +// rotateFile closes the syncBuffer's file and starts a new one. +func (sb *syncBuffer) rotateFile(now time.Time) error { + if sb.file != nil { + sb.Flush() + sb.file.Close() + } + var err error + sb.file, _, err = create(severityName[sb.sev], now) + sb.nbytes = 0 + if err != nil { + return err + } + + sb.Writer = bufio.NewWriterSize(sb.file, bufferSize) + + // Write header. + var buf bytes.Buffer + fmt.Fprintf(&buf, "Log file created at: %s\n", now.Format("2006/01/02 15:04:05")) + fmt.Fprintf(&buf, "Running on machine: %s\n", host) + fmt.Fprintf(&buf, "Binary: Built with %s %s for %s/%s\n", runtime.Compiler, runtime.Version(), runtime.GOOS, runtime.GOARCH) + fmt.Fprintf(&buf, "Log line format: [IWEF]mmdd hh:mm:ss.uuuuuu threadid file:line] msg\n") + n, err := sb.file.Write(buf.Bytes()) + sb.nbytes += uint64(n) + return err +} + +// bufferSize sizes the buffer associated with each log file. It's large +// so that log records can accumulate without the logging thread blocking +// on disk I/O. The flushDaemon will block instead. +const bufferSize = 256 * 1024 + +// createFiles creates all the log files for severity from sev down to infoLog. +// l.mu is held. +func (l *loggingT) createFiles(sev severity) error { + now := time.Now() + // Files are created in decreasing severity order, so as soon as we find one + // has already been created, we can stop. + for s := sev; s >= infoLog && l.file[s] == nil; s-- { + sb := &syncBuffer{ + logger: l, + sev: s, + } + if err := sb.rotateFile(now); err != nil { + return err + } + l.file[s] = sb + } + return nil +} + +const flushInterval = 30 * time.Second + +// flushDaemon periodically flushes the log file buffers. +func (l *loggingT) flushDaemon() { + for _ = range time.NewTicker(flushInterval).C { + l.lockAndFlushAll() + } +} + +// lockAndFlushAll is like flushAll but locks l.mu first. +func (l *loggingT) lockAndFlushAll() { + l.mu.Lock() + l.flushAll() + l.mu.Unlock() +} + +// flushAll flushes all the logs and attempts to "sync" their data to disk. +// l.mu is held. +func (l *loggingT) flushAll() { + // Flush from fatal down, in case there's trouble flushing. 
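+	// fatalLog is flushed first so that the most severe records reach disk
+	// even if flushing a lower-severity file fails or blocks.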
+ for s := fatalLog; s >= infoLog; s-- { + file := l.file[s] + if file != nil { + file.Flush() // ignore error + file.Sync() // ignore error + } + } +} + +// CopyStandardLogTo arranges for messages written to the Go "log" package's +// default logs to also appear in the Google logs for the named and lower +// severities. Subsequent changes to the standard log's default output location +// or format may break this behavior. +// +// Valid names are "INFO", "WARNING", "ERROR", and "FATAL". If the name is not +// recognized, CopyStandardLogTo panics. +func CopyStandardLogTo(name string) { + sev, ok := severityByName(name) + if !ok { + panic(fmt.Sprintf("log.CopyStandardLogTo(%q): unrecognized severity name", name)) + } + // Set a log format that captures the user's file and line: + // d.go:23: message + stdLog.SetFlags(stdLog.Lshortfile) + stdLog.SetOutput(logBridge(sev)) +} + +// logBridge provides the Write method that enables CopyStandardLogTo to connect +// Go's standard logs to the logs provided by this package. +type logBridge severity + +// Write parses the standard logging line and passes its components to the +// logger for severity(lb). +func (lb logBridge) Write(b []byte) (n int, err error) { + var ( + file = "???" + line = 1 + text string + ) + // Split "d.go:23: message" into "d.go", "23", and "message". + if parts := bytes.SplitN(b, []byte{':'}, 3); len(parts) != 3 || len(parts[0]) < 1 || len(parts[2]) < 1 { + text = fmt.Sprintf("bad log format: %s", b) + } else { + file = string(parts[0]) + text = string(parts[2][1:]) // skip leading space + line, err = strconv.Atoi(string(parts[1])) + if err != nil { + text = fmt.Sprintf("bad line number: %s", b) + line = 1 + } + } + // printWithFileLine with alsoToStderr=true, so standard log messages + // always appear on standard error. + logging.printWithFileLine(severity(lb), file, line, true, text) + return len(b), nil +} + +// setV computes and remembers the V level for a given PC +// when vmodule is enabled. +// File pattern matching takes the basename of the file, stripped +// of its .go suffix, and uses filepath.Match, which is a little more +// general than the *? matching used in C++. +// l.mu is held. +func (l *loggingT) setV(pc uintptr) Level { + fn := runtime.FuncForPC(pc) + file, _ := fn.FileLine(pc) + // The file is something like /a/b/c/d.go. We want just the d. + if strings.HasSuffix(file, ".go") { + file = file[:len(file)-3] + } + if slash := strings.LastIndex(file, "/"); slash >= 0 { + file = file[slash+1:] + } + for _, filter := range l.vmodule.filter { + if filter.match(file) { + l.vmap[pc] = filter.level + return filter.level + } + } + l.vmap[pc] = 0 + return 0 +} + +// Verbose is a boolean type that implements Infof (like Printf) etc. +// See the documentation of V for more information. +type Verbose bool + +// V reports whether verbosity at the call site is at least the requested level. +// The returned value is a boolean of type Verbose, which implements Info, Infoln +// and Infof. These methods will write to the Info log if called. +// Thus, one may write either +// if glog.V(2) { glog.Info("log this") } +// or +// glog.V(2).Info("log this") +// The second form is shorter but the first is cheaper if logging is off because it does +// not evaluate its arguments. +// +// Whether an individual call to V generates a log record depends on the setting of +// the -v and --vmodule flags; both are off by default. 
If the level in the call to
+// V is at least the value of -v, or of -vmodule for the source file containing the
+// call, the V call will log.
+func V(level Level) Verbose {
+	// This function tries hard to be cheap unless there's work to do.
+	// The fast path is two atomic loads and compares.
+
+	// Here is a cheap but safe test to see if V logging is enabled globally.
+	if logging.verbosity.get() >= level {
+		return Verbose(true)
+	}
+
+	// It's off globally but vmodule may still be set.
+	// Here is another cheap but safe test to see if vmodule is enabled.
+	if atomic.LoadInt32(&logging.filterLength) > 0 {
+		// Now we need a proper lock to use the logging structure. The pcs field
+		// is shared so we must lock before accessing it. This is fairly expensive,
+		// but if V logging is enabled we're slow anyway.
+		logging.mu.Lock()
+		defer logging.mu.Unlock()
+		if runtime.Callers(2, logging.pcs[:]) == 0 {
+			return Verbose(false)
+		}
+		v, ok := logging.vmap[logging.pcs[0]]
+		if !ok {
+			v = logging.setV(logging.pcs[0])
+		}
+		return Verbose(v >= level)
+	}
+	return Verbose(false)
+}
+
+// Info is equivalent to the global Info function, guarded by the value of v.
+// See the documentation of V for usage.
+func (v Verbose) Info(args ...interface{}) {
+	if v {
+		logging.print(infoLog, args...)
+	}
+}
+
+// Infoln is equivalent to the global Infoln function, guarded by the value of v.
+// See the documentation of V for usage.
+func (v Verbose) Infoln(args ...interface{}) {
+	if v {
+		logging.println(infoLog, args...)
+	}
+}
+
+// Infof is equivalent to the global Infof function, guarded by the value of v.
+// See the documentation of V for usage.
+func (v Verbose) Infof(format string, args ...interface{}) {
+	if v {
+		logging.printf(infoLog, format, args...)
+	}
+}
+
+// Info logs to the INFO log.
+// Arguments are handled in the manner of fmt.Print; a newline is appended if missing.
+func Info(args ...interface{}) {
+	logging.print(infoLog, args...)
+}
+
+// InfoDepth acts as Info but uses depth to determine which call frame to log.
+// InfoDepth(0, "msg") is the same as Info("msg").
+func InfoDepth(depth int, args ...interface{}) {
+	logging.printDepth(infoLog, depth, args...)
+}
+
+// Infoln logs to the INFO log.
+// Arguments are handled in the manner of fmt.Println; a newline is appended if missing.
+func Infoln(args ...interface{}) {
+	logging.println(infoLog, args...)
+}
+
+// Infof logs to the INFO log.
+// Arguments are handled in the manner of fmt.Printf; a newline is appended if missing.
+func Infof(format string, args ...interface{}) {
+	logging.printf(infoLog, format, args...)
+}
+
+// Warning logs to the WARNING and INFO logs.
+// Arguments are handled in the manner of fmt.Print; a newline is appended if missing.
+func Warning(args ...interface{}) {
+	logging.print(warningLog, args...)
+}
+
+// WarningDepth acts as Warning but uses depth to determine which call frame to log.
+// WarningDepth(0, "msg") is the same as Warning("msg").
+func WarningDepth(depth int, args ...interface{}) {
+	logging.printDepth(warningLog, depth, args...)
+}
+
+// Warningln logs to the WARNING and INFO logs.
+// Arguments are handled in the manner of fmt.Println; a newline is appended if missing.
+func Warningln(args ...interface{}) {
+	logging.println(warningLog, args...)
+}
+
+// Warningf logs to the WARNING and INFO logs.
+// Arguments are handled in the manner of fmt.Printf; a newline is appended if missing.
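+//
+// An illustrative call (message and arguments are hypothetical, for
+// documentation only):
+//
+//	glog.Warningf("slow query took %v: %q", elapsed, query)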
+func Warningf(format string, args ...interface{}) { + logging.printf(warningLog, format, args...) +} + +// Error logs to the ERROR, WARNING, and INFO logs. +// Arguments are handled in the manner of fmt.Print; a newline is appended if missing. +func Error(args ...interface{}) { + logging.print(errorLog, args...) +} + +// ErrorDepth acts as Error but uses depth to determine which call frame to log. +// ErrorDepth(0, "msg") is the same as Error("msg"). +func ErrorDepth(depth int, args ...interface{}) { + logging.printDepth(errorLog, depth, args...) +} + +// Errorln logs to the ERROR, WARNING, and INFO logs. +// Arguments are handled in the manner of fmt.Println; a newline is appended if missing. +func Errorln(args ...interface{}) { + logging.println(errorLog, args...) +} + +// Errorf logs to the ERROR, WARNING, and INFO logs. +// Arguments are handled in the manner of fmt.Printf; a newline is appended if missing. +func Errorf(format string, args ...interface{}) { + logging.printf(errorLog, format, args...) +} + +// Fatal logs to the FATAL, ERROR, WARNING, and INFO logs, +// including a stack trace of all running goroutines, then calls os.Exit(255). +// Arguments are handled in the manner of fmt.Print; a newline is appended if missing. +func Fatal(args ...interface{}) { + logging.print(fatalLog, args...) +} + +// FatalDepth acts as Fatal but uses depth to determine which call frame to log. +// FatalDepth(0, "msg") is the same as Fatal("msg"). +func FatalDepth(depth int, args ...interface{}) { + logging.printDepth(fatalLog, depth, args...) +} + +// Fatalln logs to the FATAL, ERROR, WARNING, and INFO logs, +// including a stack trace of all running goroutines, then calls os.Exit(255). +// Arguments are handled in the manner of fmt.Println; a newline is appended if missing. +func Fatalln(args ...interface{}) { + logging.println(fatalLog, args...) +} + +// Fatalf logs to the FATAL, ERROR, WARNING, and INFO logs, +// including a stack trace of all running goroutines, then calls os.Exit(255). +// Arguments are handled in the manner of fmt.Printf; a newline is appended if missing. +func Fatalf(format string, args ...interface{}) { + logging.printf(fatalLog, format, args...) +} + +// fatalNoStacks is non-zero if we are to exit without dumping goroutine stacks. +// It allows Exit and relatives to use the Fatal logs. +var fatalNoStacks uint32 + +// Exit logs to the FATAL, ERROR, WARNING, and INFO logs, then calls os.Exit(1). +// Arguments are handled in the manner of fmt.Print; a newline is appended if missing. +func Exit(args ...interface{}) { + atomic.StoreUint32(&fatalNoStacks, 1) + logging.print(fatalLog, args...) +} + +// ExitDepth acts as Exit but uses depth to determine which call frame to log. +// ExitDepth(0, "msg") is the same as Exit("msg"). +func ExitDepth(depth int, args ...interface{}) { + atomic.StoreUint32(&fatalNoStacks, 1) + logging.printDepth(fatalLog, depth, args...) +} + +// Exitln logs to the FATAL, ERROR, WARNING, and INFO logs, then calls os.Exit(1). +func Exitln(args ...interface{}) { + atomic.StoreUint32(&fatalNoStacks, 1) + logging.println(fatalLog, args...) +} + +// Exitf logs to the FATAL, ERROR, WARNING, and INFO logs, then calls os.Exit(1). +// Arguments are handled in the manner of fmt.Printf; a newline is appended if missing. +func Exitf(format string, args ...interface{}) { + atomic.StoreUint32(&fatalNoStacks, 1) + logging.printf(fatalLog, format, args...) 
+} diff --git a/Godeps/_workspace/src/github.com/golang/glog/glog_file.go b/Godeps/_workspace/src/github.com/golang/glog/glog_file.go new file mode 100644 index 0000000000..65075d2811 --- /dev/null +++ b/Godeps/_workspace/src/github.com/golang/glog/glog_file.go @@ -0,0 +1,124 @@ +// Go support for leveled logs, analogous to https://code.google.com/p/google-glog/ +// +// Copyright 2013 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// File I/O for logs. + +package glog + +import ( + "errors" + "flag" + "fmt" + "os" + "os/user" + "path/filepath" + "strings" + "sync" + "time" +) + +// MaxSize is the maximum size of a log file in bytes. +var MaxSize uint64 = 1024 * 1024 * 1800 + +// logDirs lists the candidate directories for new log files. +var logDirs []string + +// If non-empty, overrides the choice of directory in which to write logs. +// See createLogDirs for the full list of possible destinations. +var logDir = flag.String("log_dir", "", "If non-empty, write log files in this directory") + +func createLogDirs() { + if *logDir != "" { + logDirs = append(logDirs, *logDir) + } + logDirs = append(logDirs, os.TempDir()) +} + +var ( + pid = os.Getpid() + program = filepath.Base(os.Args[0]) + host = "unknownhost" + userName = "unknownuser" +) + +func init() { + h, err := os.Hostname() + if err == nil { + host = shortHostname(h) + } + + current, err := user.Current() + if err == nil { + userName = current.Username + } + + // Sanitize userName since it may contain filepath separators on Windows. + userName = strings.Replace(userName, `\`, "_", -1) +} + +// shortHostname returns its argument, truncating at the first period. +// For instance, given "www.google.com" it returns "www". +func shortHostname(hostname string) string { + if i := strings.Index(hostname, "."); i >= 0 { + return hostname[:i] + } + return hostname +} + +// logName returns a new log file name containing tag, with start time t, and +// the name for the symlink for tag. +func logName(tag string, t time.Time) (name, link string) { + name = fmt.Sprintf("%s.%s.%s.log.%s.%04d%02d%02d-%02d%02d%02d.%d", + program, + host, + userName, + tag, + t.Year(), + t.Month(), + t.Day(), + t.Hour(), + t.Minute(), + t.Second(), + pid) + return name, program + "." + tag +} + +var onceLogDirs sync.Once + +// create creates a new log file and returns the file and its filename, which +// contains tag ("INFO", "FATAL", etc.) and t. If the file is created +// successfully, create also attempts to update the symlink for that tag, ignoring +// errors. 
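+//
+// As an illustration (hypothetical program name, host, user, and pid), a call
+// like create("INFO", t) with t at 2015-01-08 20:58:47 yields a name such as
+//
+//	myprog.myhost.myuser.log.INFO.20150108-205847.1234
+//
+// following the format assembled by logName above.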
+func create(tag string, t time.Time) (f *os.File, filename string, err error) { + onceLogDirs.Do(createLogDirs) + if len(logDirs) == 0 { + return nil, "", errors.New("log: no log dirs") + } + name, link := logName(tag, t) + var lastErr error + for _, dir := range logDirs { + fname := filepath.Join(dir, name) + f, err := os.Create(fname) + if err == nil { + symlink := filepath.Join(dir, link) + os.Remove(symlink) // ignore err + os.Symlink(name, symlink) // ignore err + return f, fname, nil + } + lastErr = err + } + return nil, "", fmt.Errorf("log: cannot create log: %v", lastErr) +} diff --git a/Godeps/_workspace/src/github.com/golang/glog/glog_test.go b/Godeps/_workspace/src/github.com/golang/glog/glog_test.go new file mode 100644 index 0000000000..0fb376e1fd --- /dev/null +++ b/Godeps/_workspace/src/github.com/golang/glog/glog_test.go @@ -0,0 +1,415 @@ +// Go support for leveled logs, analogous to https://code.google.com/p/google-glog/ +// +// Copyright 2013 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package glog + +import ( + "bytes" + "fmt" + stdLog "log" + "path/filepath" + "runtime" + "strconv" + "strings" + "testing" + "time" +) + +// Test that shortHostname works as advertised. +func TestShortHostname(t *testing.T) { + for hostname, expect := range map[string]string{ + "": "", + "host": "host", + "host.google.com": "host", + } { + if got := shortHostname(hostname); expect != got { + t.Errorf("shortHostname(%q): expected %q, got %q", hostname, expect, got) + } + } +} + +// flushBuffer wraps a bytes.Buffer to satisfy flushSyncWriter. +type flushBuffer struct { + bytes.Buffer +} + +func (f *flushBuffer) Flush() error { + return nil +} + +func (f *flushBuffer) Sync() error { + return nil +} + +// swap sets the log writers and returns the old array. +func (l *loggingT) swap(writers [numSeverity]flushSyncWriter) (old [numSeverity]flushSyncWriter) { + l.mu.Lock() + defer l.mu.Unlock() + old = l.file + for i, w := range writers { + logging.file[i] = w + } + return +} + +// newBuffers sets the log writers to all new byte buffers and returns the old array. +func (l *loggingT) newBuffers() [numSeverity]flushSyncWriter { + return l.swap([numSeverity]flushSyncWriter{new(flushBuffer), new(flushBuffer), new(flushBuffer), new(flushBuffer)}) +} + +// contents returns the specified log value as a string. +func contents(s severity) string { + return logging.file[s].(*flushBuffer).String() +} + +// contains reports whether the string is contained in the log. +func contains(s severity, str string, t *testing.T) bool { + return strings.Contains(contents(s), str) +} + +// setFlags configures the logging flags how the test expects them. +func setFlags() { + logging.toStderr = false +} + +// Test that Info works as advertised. 
+func TestInfo(t *testing.T) { + setFlags() + defer logging.swap(logging.newBuffers()) + Info("test") + if !contains(infoLog, "I", t) { + t.Errorf("Info has wrong character: %q", contents(infoLog)) + } + if !contains(infoLog, "test", t) { + t.Error("Info failed") + } +} + +func TestInfoDepth(t *testing.T) { + setFlags() + defer logging.swap(logging.newBuffers()) + + f := func() { InfoDepth(1, "depth-test1") } + + // The next three lines must stay together + _, _, wantLine, _ := runtime.Caller(0) + InfoDepth(0, "depth-test0") + f() + + msgs := strings.Split(strings.TrimSuffix(contents(infoLog), "\n"), "\n") + if len(msgs) != 2 { + t.Fatalf("Got %d lines, expected 2", len(msgs)) + } + + for i, m := range msgs { + if !strings.HasPrefix(m, "I") { + t.Errorf("InfoDepth[%d] has wrong character: %q", i, m) + } + w := fmt.Sprintf("depth-test%d", i) + if !strings.Contains(m, w) { + t.Errorf("InfoDepth[%d] missing %q: %q", i, w, m) + } + + // pull out the line number (between : and ]) + msg := m[strings.LastIndex(m, ":")+1:] + x := strings.Index(msg, "]") + if x < 0 { + t.Errorf("InfoDepth[%d]: missing ']': %q", i, m) + continue + } + line, err := strconv.Atoi(msg[:x]) + if err != nil { + t.Errorf("InfoDepth[%d]: bad line number: %q", i, m) + continue + } + wantLine++ + if wantLine != line { + t.Errorf("InfoDepth[%d]: got line %d, want %d", i, line, wantLine) + } + } +} + +func init() { + CopyStandardLogTo("INFO") +} + +// Test that CopyStandardLogTo panics on bad input. +func TestCopyStandardLogToPanic(t *testing.T) { + defer func() { + if s, ok := recover().(string); !ok || !strings.Contains(s, "LOG") { + t.Errorf(`CopyStandardLogTo("LOG") should have panicked: %v`, s) + } + }() + CopyStandardLogTo("LOG") +} + +// Test that using the standard log package logs to INFO. +func TestStandardLog(t *testing.T) { + setFlags() + defer logging.swap(logging.newBuffers()) + stdLog.Print("test") + if !contains(infoLog, "I", t) { + t.Errorf("Info has wrong character: %q", contents(infoLog)) + } + if !contains(infoLog, "test", t) { + t.Error("Info failed") + } +} + +// Test that the header has the correct format. +func TestHeader(t *testing.T) { + setFlags() + defer logging.swap(logging.newBuffers()) + defer func(previous func() time.Time) { timeNow = previous }(timeNow) + timeNow = func() time.Time { + return time.Date(2006, 1, 2, 15, 4, 5, .067890e9, time.Local) + } + pid = 1234 + Info("test") + var line int + format := "I0102 15:04:05.067890 1234 glog_test.go:%d] test\n" + n, err := fmt.Sscanf(contents(infoLog), format, &line) + if n != 1 || err != nil { + t.Errorf("log format error: %d elements, error %s:\n%s", n, err, contents(infoLog)) + } + // Scanf treats multiple spaces as equivalent to a single space, + // so check for correct space-padding also. + want := fmt.Sprintf(format, line) + if contents(infoLog) != want { + t.Errorf("log format error: got:\n\t%q\nwant:\t%q", contents(infoLog), want) + } +} + +// Test that an Error log goes to Warning and Info. +// Even in the Info log, the source character will be E, so the data should +// all be identical. 
+func TestError(t *testing.T) {
+	setFlags()
+	defer logging.swap(logging.newBuffers())
+	Error("test")
+	if !contains(errorLog, "E", t) {
+		t.Errorf("Error has wrong character: %q", contents(errorLog))
+	}
+	if !contains(errorLog, "test", t) {
+		t.Error("Error failed")
+	}
+	str := contents(errorLog)
+	if !contains(warningLog, str, t) {
+		t.Error("Warning failed")
+	}
+	if !contains(infoLog, str, t) {
+		t.Error("Info failed")
+	}
+}
+
+// Test that a Warning log goes to Info.
+// Even in the Info log, the source character will be W, so the data should
+// all be identical.
+func TestWarning(t *testing.T) {
+	setFlags()
+	defer logging.swap(logging.newBuffers())
+	Warning("test")
+	if !contains(warningLog, "W", t) {
+		t.Errorf("Warning has wrong character: %q", contents(warningLog))
+	}
+	if !contains(warningLog, "test", t) {
+		t.Error("Warning failed")
+	}
+	str := contents(warningLog)
+	if !contains(infoLog, str, t) {
+		t.Error("Info failed")
+	}
+}
+
+// Test that a V log goes to Info.
+func TestV(t *testing.T) {
+	setFlags()
+	defer logging.swap(logging.newBuffers())
+	logging.verbosity.Set("2")
+	defer logging.verbosity.Set("0")
+	V(2).Info("test")
+	if !contains(infoLog, "I", t) {
+		t.Errorf("Info has wrong character: %q", contents(infoLog))
+	}
+	if !contains(infoLog, "test", t) {
+		t.Error("Info failed")
+	}
+}
+
+// Test that a vmodule enables a log in this file.
+func TestVmoduleOn(t *testing.T) {
+	setFlags()
+	defer logging.swap(logging.newBuffers())
+	logging.vmodule.Set("glog_test=2")
+	defer logging.vmodule.Set("")
+	if !V(1) {
+		t.Error("V not enabled for 1")
+	}
+	if !V(2) {
+		t.Error("V not enabled for 2")
+	}
+	if V(3) {
+		t.Error("V enabled for 3")
+	}
+	V(2).Info("test")
+	if !contains(infoLog, "I", t) {
+		t.Errorf("Info has wrong character: %q", contents(infoLog))
+	}
+	if !contains(infoLog, "test", t) {
+		t.Error("Info failed")
+	}
+}
+
+// Test that a vmodule of another file does not enable a log in this file.
+func TestVmoduleOff(t *testing.T) {
+	setFlags()
+	defer logging.swap(logging.newBuffers())
+	logging.vmodule.Set("notthisfile=2")
+	defer logging.vmodule.Set("")
+	for i := 1; i <= 3; i++ {
+		if V(Level(i)) {
+			t.Errorf("V enabled for %d", i)
+		}
+	}
+	V(2).Info("test")
+	if contents(infoLog) != "" {
+		t.Error("V logged incorrectly")
+	}
+}
+
+// vGlobs are patterns that match/don't match this file at V=2.
+var vGlobs = map[string]bool{
+	// Easy to test the numeric match here.
+	"glog_test=1": false, // If -vmodule sets V to 1, V(2) will fail.
+	"glog_test=2": true,
+	"glog_test=3": true, // If -vmodule sets V to 3, V(2) will succeed.
+	// These all use 2 and check the patterns. All are true.
+	"*=2":           true,
+	"?l*=2":         true,
+	"????_*=2":      true,
+	"??[mno]?_*t=2": true,
+	// These all use 2 and check the patterns. All are false.
+	"*x=2":         false,
+	"m*=2":         false,
+	"??_*=2":       false,
+	"?[abc]?_*t=2": false,
+}
+
+// Test that vmodule globbing works as advertised.
+func testVmoduleGlob(pat string, match bool, t *testing.T) {
+	setFlags()
+	defer logging.swap(logging.newBuffers())
+	defer logging.vmodule.Set("")
+	logging.vmodule.Set(pat)
+	if V(2) != Verbose(match) {
+		t.Errorf("incorrect match for %q: got %t expected %t", pat, V(2), match)
+	}
+}
+
+// Test that vmodule globbing works as advertised.
+func TestVmoduleGlob(t *testing.T) { + for glob, match := range vGlobs { + testVmoduleGlob(glob, match, t) + } +} + +func TestRollover(t *testing.T) { + setFlags() + var err error + defer func(previous func(error)) { logExitFunc = previous }(logExitFunc) + logExitFunc = func(e error) { + err = e + } + defer func(previous uint64) { MaxSize = previous }(MaxSize) + MaxSize = 512 + + Info("x") // Be sure we have a file. + info, ok := logging.file[infoLog].(*syncBuffer) + if !ok { + t.Fatal("info wasn't created") + } + if err != nil { + t.Fatalf("info has initial error: %v", err) + } + fname0 := info.file.Name() + Info(strings.Repeat("x", int(MaxSize))) // force a rollover + if err != nil { + t.Fatalf("info has error after big write: %v", err) + } + + // Make sure the next log file gets a file name with a different + // time stamp. + // + // TODO: determine whether we need to support subsecond log + // rotation. C++ does not appear to handle this case (nor does it + // handle Daylight Savings Time properly). + time.Sleep(1 * time.Second) + + Info("x") // create a new file + if err != nil { + t.Fatalf("error after rotation: %v", err) + } + fname1 := info.file.Name() + if fname0 == fname1 { + t.Errorf("info.f.Name did not change: %v", fname0) + } + if info.nbytes >= MaxSize { + t.Errorf("file size was not reset: %d", info.nbytes) + } +} + +func TestLogBacktraceAt(t *testing.T) { + setFlags() + defer logging.swap(logging.newBuffers()) + // The peculiar style of this code simplifies line counting and maintenance of the + // tracing block below. + var infoLine string + setTraceLocation := func(file string, line int, ok bool, delta int) { + if !ok { + t.Fatal("could not get file:line") + } + _, file = filepath.Split(file) + infoLine = fmt.Sprintf("%s:%d", file, line+delta) + err := logging.traceLocation.Set(infoLine) + if err != nil { + t.Fatal("error setting log_backtrace_at: ", err) + } + } + { + // Start of tracing block. These lines know about each other's relative position. + _, file, line, ok := runtime.Caller(0) + setTraceLocation(file, line, ok, +2) // Two lines between Caller and Info calls. + Info("we want a stack trace here") + } + numAppearances := strings.Count(contents(infoLog), infoLine) + if numAppearances < 2 { + // Need 2 appearances, one in the log header and one in the trace: + // log_test.go:281: I0511 16:36:06.952398 02238 log_test.go:280] we want a stack trace here + // ... + // github.com/glog/glog_test.go:280 (0x41ba91) + // ... + // We could be more precise but that would require knowing the details + // of the traceback format, which may not be dependable. + t.Fatal("got no trace back; log is ", contents(infoLog)) + } +} + +func BenchmarkHeader(b *testing.B) { + for i := 0; i < b.N; i++ { + buf, _, _ := logging.header(infoLog, 0) + logging.putBuffer(buf) + } +} diff --git a/Godeps/_workspace/src/github.com/matttproud/golang_protobuf_extensions/ext/all_test.go b/Godeps/_workspace/src/github.com/matttproud/golang_protobuf_extensions/ext/all_test.go new file mode 100644 index 0000000000..98d4e2d818 --- /dev/null +++ b/Godeps/_workspace/src/github.com/matttproud/golang_protobuf_extensions/ext/all_test.go @@ -0,0 +1,355 @@ +// Copyright 2013 Matt T. Proud +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ext + +import ( + "bytes" + "math/rand" + "reflect" + "testing" + "testing/quick" + + . "code.google.com/p/goprotobuf/proto" + . "code.google.com/p/goprotobuf/proto/testdata" +) + +func TestWriteDelimited(t *testing.T) { + for _, test := range []struct { + msg Message + buf []byte + n int + err error + }{ + { + msg: &Empty{}, + n: 1, + buf: []byte{0}, + }, + { + msg: &GoEnum{Foo: FOO_FOO1.Enum()}, + n: 3, + buf: []byte{2, 8, 1}, + }, + { + msg: &Strings{ + StringField: String(`This is my gigantic, unhappy string. It exceeds +the encoding size of a single byte varint. We are using it to fuzz test the +correctness of the header decoding mechanisms, which may prove problematic. +I expect it may. Let's hope you enjoy testing as much as we do.`), + }, + n: 271, + buf: []byte{141, 2, 10, 138, 2, 84, 104, 105, 115, 32, 105, 115, 32, 109, + 121, 32, 103, 105, 103, 97, 110, 116, 105, 99, 44, 32, 117, 110, 104, + 97, 112, 112, 121, 32, 115, 116, 114, 105, 110, 103, 46, 32, 32, 73, + 116, 32, 101, 120, 99, 101, 101, 100, 115, 10, 116, 104, 101, 32, 101, + 110, 99, 111, 100, 105, 110, 103, 32, 115, 105, 122, 101, 32, 111, 102, + 32, 97, 32, 115, 105, 110, 103, 108, 101, 32, 98, 121, 116, 101, 32, + 118, 97, 114, 105, 110, 116, 46, 32, 32, 87, 101, 32, 97, 114, 101, 32, + 117, 115, 105, 110, 103, 32, 105, 116, 32, 116, 111, 32, 102, 117, 122, + 122, 32, 116, 101, 115, 116, 32, 116, 104, 101, 10, 99, 111, 114, 114, + 101, 99, 116, 110, 101, 115, 115, 32, 111, 102, 32, 116, 104, 101, 32, + 104, 101, 97, 100, 101, 114, 32, 100, 101, 99, 111, 100, 105, 110, 103, + 32, 109, 101, 99, 104, 97, 110, 105, 115, 109, 115, 44, 32, 119, 104, + 105, 99, 104, 32, 109, 97, 121, 32, 112, 114, 111, 118, 101, 32, 112, + 114, 111, 98, 108, 101, 109, 97, 116, 105, 99, 46, 10, 73, 32, 101, 120, + 112, 101, 99, 116, 32, 105, 116, 32, 109, 97, 121, 46, 32, 32, 76, 101, + 116, 39, 115, 32, 104, 111, 112, 101, 32, 121, 111, 117, 32, 101, 110, + 106, 111, 121, 32, 116, 101, 115, 116, 105, 110, 103, 32, 97, 115, 32, + 109, 117, 99, 104, 32, 97, 115, 32, 119, 101, 32, 100, 111, 46}, + }, + } { + var buf bytes.Buffer + if n, err := WriteDelimited(&buf, test.msg); n != test.n || err != test.err { + t.Fatalf("WriteDelimited(buf, %#v) = %v, %v; want %v, %v", test.msg, n, err, test.n, test.err) + } + if out := buf.Bytes(); !bytes.Equal(out, test.buf) { + t.Fatalf("WriteDelimited(buf, %#v); buf = %v; want %v", test.msg, out, test.buf) + } + } +} + +func TestReadDelimited(t *testing.T) { + for _, test := range []struct { + buf []byte + msg Message + n int + err error + }{ + { + buf: []byte{0}, + msg: &Empty{}, + n: 1, + }, + { + n: 3, + buf: []byte{2, 8, 1}, + msg: &GoEnum{Foo: FOO_FOO1.Enum()}, + }, + { + buf: []byte{141, 2, 10, 138, 2, 84, 104, 105, 115, 32, 105, 115, 32, 109, + 121, 32, 103, 105, 103, 97, 110, 116, 105, 99, 44, 32, 117, 110, 104, + 97, 112, 112, 121, 32, 115, 116, 114, 105, 110, 103, 46, 32, 32, 73, + 116, 32, 101, 120, 99, 101, 101, 100, 115, 10, 116, 104, 101, 32, 101, + 110, 99, 111, 100, 105, 110, 103, 32, 115, 105, 122, 101, 32, 111, 102, + 32, 97, 32, 115, 105, 110, 103, 108, 101, 32, 
98, 121, 116, 101, 32, + 118, 97, 114, 105, 110, 116, 46, 32, 32, 87, 101, 32, 97, 114, 101, 32, + 117, 115, 105, 110, 103, 32, 105, 116, 32, 116, 111, 32, 102, 117, 122, + 122, 32, 116, 101, 115, 116, 32, 116, 104, 101, 10, 99, 111, 114, 114, + 101, 99, 116, 110, 101, 115, 115, 32, 111, 102, 32, 116, 104, 101, 32, + 104, 101, 97, 100, 101, 114, 32, 100, 101, 99, 111, 100, 105, 110, 103, + 32, 109, 101, 99, 104, 97, 110, 105, 115, 109, 115, 44, 32, 119, 104, + 105, 99, 104, 32, 109, 97, 121, 32, 112, 114, 111, 118, 101, 32, 112, + 114, 111, 98, 108, 101, 109, 97, 116, 105, 99, 46, 10, 73, 32, 101, 120, + 112, 101, 99, 116, 32, 105, 116, 32, 109, 97, 121, 46, 32, 32, 76, 101, + 116, 39, 115, 32, 104, 111, 112, 101, 32, 121, 111, 117, 32, 101, 110, + 106, 111, 121, 32, 116, 101, 115, 116, 105, 110, 103, 32, 97, 115, 32, + 109, 117, 99, 104, 32, 97, 115, 32, 119, 101, 32, 100, 111, 46}, + msg: &Strings{ + StringField: String(`This is my gigantic, unhappy string. It exceeds +the encoding size of a single byte varint. We are using it to fuzz test the +correctness of the header decoding mechanisms, which may prove problematic. +I expect it may. Let's hope you enjoy testing as much as we do.`), + }, + n: 271, + }, + } { + msg := Clone(test.msg) + msg.Reset() + if n, err := ReadDelimited(bytes.NewBuffer(test.buf), msg); n != test.n || err != test.err { + t.Fatalf("ReadDelimited(%v, msg) = %v, %v; want %v, %v", test.buf, n, err, test.n, test.err) + } + if !Equal(msg, test.msg) { + t.Fatalf("ReadDelimited(%v, msg); msg = %v; want %v", test.buf, msg, test.msg) + } + } +} + +func TestEndToEndValid(t *testing.T) { + for _, test := range [][]Message{ + []Message{&Empty{}}, + []Message{&GoEnum{Foo: FOO_FOO1.Enum()}, &Empty{}, &GoEnum{Foo: FOO_FOO1.Enum()}}, + []Message{&GoEnum{Foo: FOO_FOO1.Enum()}}, + []Message{&Strings{ + StringField: String(`This is my gigantic, unhappy string. It exceeds +the encoding size of a single byte varint. We are using it to fuzz test the +correctness of the header decoding mechanisms, which may prove problematic. +I expect it may. Let's hope you enjoy testing as much as we do.`), + }}, + } { + var buf bytes.Buffer + var written int + for i, msg := range test { + n, err := WriteDelimited(&buf, msg) + if err != nil { + // Assumption: TestReadDelimited and TestWriteDelimited are sufficient + // and inputs for this test are explicitly exercised there. + t.Fatalf("WriteDelimited(buf, %v[%d]) = ?, %v; wanted ?, nil", test, i, err) + } + written += n + } + var read int + for i, msg := range test { + out := Clone(msg) + out.Reset() + n, _ := ReadDelimited(&buf, out) + // Decide to do EOF checking? + read += n + if !Equal(out, msg) { + t.Fatalf("out = %v; want %v[%d] = %#v", out, test, i, msg) + } + } + if read != written { + t.Fatalf("%v read = %d; want %d", test, read, written) + } + } +} + +// visitMessage empties the private state fields of the quick.Value()-generated +// Protocol Buffer messages, for they cause an inordinate amount of problems. +// This is because we are using an automated fuzz generator on a type with +// private fields. 
+func visitMessage(m Message) { + t := reflect.TypeOf(m) + if t.Kind() != reflect.Ptr { + return + } + derefed := t.Elem() + if derefed.Kind() != reflect.Struct { + return + } + v := reflect.ValueOf(m) + elem := v.Elem() + for i := 0; i < elem.NumField(); i++ { + field := elem.FieldByIndex([]int{i}) + fieldType := field.Type() + if fieldType.Implements(reflect.TypeOf((*Message)(nil)).Elem()) { + visitMessage(field.Interface().(Message)) + } + if field.Kind() == reflect.Slice { + for i := 0; i < field.Len(); i++ { + elem := field.Index(i) + elemType := elem.Type() + if elemType.Implements(reflect.TypeOf((*Message)(nil)).Elem()) { + visitMessage(elem.Interface().(Message)) + } + } + } + } + if field := elem.FieldByName("XXX_unrecognized"); field.IsValid() { + field.Set(reflect.ValueOf([]byte{})) + } + if field := elem.FieldByName("XXX_extensions"); field.IsValid() { + field.Set(reflect.ValueOf(nil)) + } +} + +// rndMessage generates a random valid Protocol Buffer message. +func rndMessage(r *rand.Rand) Message { + var t reflect.Type + switch v := rand.Intn(23); v { + // TODO(br): Uncomment the elements below once fix is incorporated, except + // for the elements marked as patently incompatible. + // case 0: + // t = reflect.TypeOf(&GoEnum{}) + // break + // case 1: + // t = reflect.TypeOf(&GoTestField{}) + // break + case 2: + t = reflect.TypeOf(&GoTest{}) + break + // case 3: + // t = reflect.TypeOf(&GoSkipTest{}) + // break + // case 4: + // t = reflect.TypeOf(&NonPackedTest{}) + // break + // case 5: + // t = reflect.TypeOf(&PackedTest{}) + // break + case 6: + t = reflect.TypeOf(&MaxTag{}) + break + case 7: + t = reflect.TypeOf(&OldMessage{}) + break + case 8: + t = reflect.TypeOf(&NewMessage{}) + break + case 9: + t = reflect.TypeOf(&InnerMessage{}) + break + case 10: + t = reflect.TypeOf(&OtherMessage{}) + break + case 11: + // PATENTLY INVALID FOR FUZZ GENERATION + // t = reflect.TypeOf(&MyMessage{}) + break + // case 12: + // t = reflect.TypeOf(&Ext{}) + // break + case 13: + // PATENTLY INVALID FOR FUZZ GENERATION + // t = reflect.TypeOf(&MyMessageSet{}) + break + // case 14: + // t = reflect.TypeOf(&Empty{}) + // break + // case 15: + // t = reflect.TypeOf(&MessageList{}) + // break + // case 16: + // t = reflect.TypeOf(&Strings{}) + // break + // case 17: + // t = reflect.TypeOf(&Defaults{}) + // break + // case 17: + // t = reflect.TypeOf(&SubDefaults{}) + // break + // case 18: + // t = reflect.TypeOf(&RepeatedEnum{}) + // break + case 19: + t = reflect.TypeOf(&MoreRepeated{}) + break + // case 20: + // t = reflect.TypeOf(&GroupOld{}) + // break + // case 21: + // t = reflect.TypeOf(&GroupNew{}) + // break + case 22: + t = reflect.TypeOf(&FloatingPoint{}) + break + default: + // TODO(br): Replace with an unreachable once fixed. + t = reflect.TypeOf(&GoTest{}) + break + } + if t == nil { + t = reflect.TypeOf(&GoTest{}) + } + v, ok := quick.Value(t, r) + if !ok { + panic("attempt to generate illegal item; consult item 11") + } + visitMessage(v.Interface().(Message)) + return v.Interface().(Message) +} + +// rndMessages generates several random Protocol Buffer messages. 
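+// Note that rndMessage above draws its type selection from the package-level
+// rand (rand.Intn), not from r, so seeding r (TestFuzz below uses source 42)
+// does not make a run fully deterministic.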
+func rndMessages(r *rand.Rand) []Message { + n := r.Intn(128) + out := make([]Message, 0, n) + for i := 0; i < n; i++ { + out = append(out, rndMessage(r)) + } + return out +} + +func TestFuzz(t *testing.T) { + rnd := rand.New(rand.NewSource(42)) + check := func() bool { + messages := rndMessages(rnd) + var buf bytes.Buffer + var written int + for i, msg := range messages { + n, err := WriteDelimited(&buf, msg) + if err != nil { + t.Fatalf("WriteDelimited(buf, %v[%d]) = ?, %v; wanted ?, nil", messages, i, err) + } + written += n + } + var read int + for i, msg := range messages { + out := Clone(msg) + out.Reset() + n, _ := ReadDelimited(&buf, out) + read += n + if !Equal(out, msg) { + t.Fatalf("out = %v; want %v[%d] = %#v", out, messages, i, msg) + } + } + if read != written { + t.Fatalf("%v read = %d; want %d", messages, read, written) + } + return true + } + if err := quick.Check(check, nil); err != nil { + t.Fatal(err) + } +} diff --git a/Godeps/_workspace/src/github.com/matttproud/golang_protobuf_extensions/ext/decode.go b/Godeps/_workspace/src/github.com/matttproud/golang_protobuf_extensions/ext/decode.go new file mode 100644 index 0000000000..a9af23333c --- /dev/null +++ b/Godeps/_workspace/src/github.com/matttproud/golang_protobuf_extensions/ext/decode.go @@ -0,0 +1,75 @@ +// Copyright 2013 Matt T. Proud +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ext + +import ( + "encoding/binary" + "errors" + "io" + + "code.google.com/p/goprotobuf/proto" +) + +var errInvalidVarint = errors.New("invalid varint32 encountered") + +// ReadDelimited decodes a message from the provided length-delimited stream, +// where the length is encoded as 32-bit varint prefix to the message body. +// It returns the total number of bytes read and any applicable error. This is +// roughly equivalent to the companion Java API's +// MessageLite#parseDelimitedFrom. As per the reader contract, this function +// calls r.Read repeatedly as required until exactly one message including its +// prefix is read and decoded (or an error has occurred). The function never +// reads more bytes from the stream than required. The function never returns +// an error if a message has been read and decoded correctly, even if the end +// of the stream has been reached in doing so. In that case, any subsequent +// calls return (0, io.EOF). +func ReadDelimited(r io.Reader, m proto.Message) (n int, err error) { + // Per AbstractParser#parsePartialDelimitedFrom with + // CodedInputStream#readRawVarint32. + headerBuf := make([]byte, binary.MaxVarintLen32) + var bytesRead, varIntBytes int + var messageLength uint64 + for varIntBytes == 0 { // i.e. no varint has been decoded yet. + if bytesRead >= len(headerBuf) { + return bytesRead, errInvalidVarint + } + // We have to read byte by byte here to avoid reading more bytes + // than required. Each read byte is appended to what we have + // read before. 
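+		// Worked example (cf. the test fixtures above): the header
+		// bytes {0x8d, 0x02} decode as 0x0d | 0x02<<7 = 269, i.e. a
+		// 269-byte message after a 2-byte varint, 271 bytes in total.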
+ newBytesRead, err := r.Read(headerBuf[bytesRead : bytesRead+1]) + if newBytesRead == 0 { + if err != nil { + return bytesRead, err + } + // A Reader should not return (0, nil), but if it does, + // it should be treated as no-op (according to the + // Reader contract). So let's go on... + continue + } + bytesRead += newBytesRead + // Now present everything read so far to the varint decoder and + // see if a varint can be decoded already. + messageLength, varIntBytes = proto.DecodeVarint(headerBuf[:bytesRead]) + } + + messageBuf := make([]byte, messageLength) + newBytesRead, err := io.ReadFull(r, messageBuf) + bytesRead += newBytesRead + if err != nil { + return bytesRead, err + } + + return bytesRead, proto.Unmarshal(messageBuf, m) +} diff --git a/Godeps/_workspace/src/github.com/matttproud/golang_protobuf_extensions/ext/doc.go b/Godeps/_workspace/src/github.com/matttproud/golang_protobuf_extensions/ext/doc.go new file mode 100644 index 0000000000..652907bd2e --- /dev/null +++ b/Godeps/_workspace/src/github.com/matttproud/golang_protobuf_extensions/ext/doc.go @@ -0,0 +1,16 @@ +// Copyright 2013 Matt T. Proud +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package ext enables record length-delimited Protocol Buffer streaming. +package ext diff --git a/Godeps/_workspace/src/github.com/matttproud/golang_protobuf_extensions/ext/encode.go b/Godeps/_workspace/src/github.com/matttproud/golang_protobuf_extensions/ext/encode.go new file mode 100644 index 0000000000..a7a9345615 --- /dev/null +++ b/Godeps/_workspace/src/github.com/matttproud/golang_protobuf_extensions/ext/encode.go @@ -0,0 +1,46 @@ +// Copyright 2013 Matt T. Proud +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ext + +import ( + "encoding/binary" + "io" + + "code.google.com/p/goprotobuf/proto" +) + +// WriteDelimited encodes and dumps a message to the provided writer prefixed +// with a 32-bit varint indicating the length of the encoded message, producing +// a length-delimited record stream, which can be used to chain together +// encoded messages of the same type together in a file. It returns the total +// number of bytes written and any applicable error. This is roughly +// equivalent to the companion Java API's MessageLite#writeDelimitedTo. 
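+//
+// A minimal round-trip sketch (msg stands for any concrete proto.Message;
+// error handling elided), mirroring TestEndToEndValid:
+//
+//	var buf bytes.Buffer
+//	WriteDelimited(&buf, msg)  // varint length header, then the body
+//	out := proto.Clone(msg)
+//	out.Reset()
+//	ReadDelimited(&buf, out)   // out is now Equal to msg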
+func WriteDelimited(w io.Writer, m proto.Message) (n int, err error) { + buffer, err := proto.Marshal(m) + if err != nil { + return 0, err + } + + buf := make([]byte, binary.MaxVarintLen32) + encodedLength := binary.PutUvarint(buf, uint64(len(buffer))) + + sync, err := w.Write(buf[:encodedLength]) + if err != nil { + return sync, err + } + + n, err = w.Write(buffer) + return n + sync, err +} diff --git a/Godeps/_workspace/src/github.com/matttproud/golang_protobuf_extensions/ext/fixtures_test.go b/Godeps/_workspace/src/github.com/matttproud/golang_protobuf_extensions/ext/fixtures_test.go new file mode 100644 index 0000000000..9cfcbbf585 --- /dev/null +++ b/Godeps/_workspace/src/github.com/matttproud/golang_protobuf_extensions/ext/fixtures_test.go @@ -0,0 +1,103 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// http://code.google.com/p/goprotobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package ext + +import ( + . "code.google.com/p/goprotobuf/proto" + . "code.google.com/p/goprotobuf/proto/testdata" +) + +// FROM https://code.google.com/p/goprotobuf/source/browse/proto/all_test.go. + +func initGoTestField() *GoTestField { + f := new(GoTestField) + f.Label = String("label") + f.Type = String("type") + return f +} + +// These are all structurally equivalent but the tag numbers differ. +// (It's remarkable that required, optional, and repeated all have +// 8 letters.) 
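+// The three group constructors below differ only in the literal stored in
+// RequiredField ("required", "optional", "repeated").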
+func initGoTest_RequiredGroup() *GoTest_RequiredGroup { + return &GoTest_RequiredGroup{ + RequiredField: String("required"), + } +} + +func initGoTest_OptionalGroup() *GoTest_OptionalGroup { + return &GoTest_OptionalGroup{ + RequiredField: String("optional"), + } +} + +func initGoTest_RepeatedGroup() *GoTest_RepeatedGroup { + return &GoTest_RepeatedGroup{ + RequiredField: String("repeated"), + } +} + +func initGoTest(setdefaults bool) *GoTest { + pb := new(GoTest) + if setdefaults { + pb.F_BoolDefaulted = Bool(Default_GoTest_F_BoolDefaulted) + pb.F_Int32Defaulted = Int32(Default_GoTest_F_Int32Defaulted) + pb.F_Int64Defaulted = Int64(Default_GoTest_F_Int64Defaulted) + pb.F_Fixed32Defaulted = Uint32(Default_GoTest_F_Fixed32Defaulted) + pb.F_Fixed64Defaulted = Uint64(Default_GoTest_F_Fixed64Defaulted) + pb.F_Uint32Defaulted = Uint32(Default_GoTest_F_Uint32Defaulted) + pb.F_Uint64Defaulted = Uint64(Default_GoTest_F_Uint64Defaulted) + pb.F_FloatDefaulted = Float32(Default_GoTest_F_FloatDefaulted) + pb.F_DoubleDefaulted = Float64(Default_GoTest_F_DoubleDefaulted) + pb.F_StringDefaulted = String(Default_GoTest_F_StringDefaulted) + pb.F_BytesDefaulted = Default_GoTest_F_BytesDefaulted + pb.F_Sint32Defaulted = Int32(Default_GoTest_F_Sint32Defaulted) + pb.F_Sint64Defaulted = Int64(Default_GoTest_F_Sint64Defaulted) + } + + pb.Kind = GoTest_TIME.Enum() + pb.RequiredField = initGoTestField() + pb.F_BoolRequired = Bool(true) + pb.F_Int32Required = Int32(3) + pb.F_Int64Required = Int64(6) + pb.F_Fixed32Required = Uint32(32) + pb.F_Fixed64Required = Uint64(64) + pb.F_Uint32Required = Uint32(3232) + pb.F_Uint64Required = Uint64(6464) + pb.F_FloatRequired = Float32(3232) + pb.F_DoubleRequired = Float64(6464) + pb.F_StringRequired = String("string") + pb.F_BytesRequired = []byte("bytes") + pb.F_Sint32Required = Int32(-32) + pb.F_Sint64Required = Int64(-64) + pb.Requiredgroup = initGoTest_RequiredGroup() + + return pb +} diff --git a/Godeps/_workspace/src/github.com/miekg/dns/.gitignore b/Godeps/_workspace/src/github.com/miekg/dns/.gitignore new file mode 100644 index 0000000000..776cd950c2 --- /dev/null +++ b/Godeps/_workspace/src/github.com/miekg/dns/.gitignore @@ -0,0 +1,4 @@ +*.6 +tags +test.out +a.out diff --git a/Godeps/_workspace/src/github.com/miekg/dns/.travis.yml b/Godeps/_workspace/src/github.com/miekg/dns/.travis.yml new file mode 100644 index 0000000000..acf314c6df --- /dev/null +++ b/Godeps/_workspace/src/github.com/miekg/dns/.travis.yml @@ -0,0 +1,19 @@ +language: go +go: + - 1.2 + - 1.3 +env: + # "gvm update" resets GOOS and GOARCH environment variable, workaround it by setting + # BUILD_GOOS and BUILD_GOARCH and overriding GOARCH and GOOS in the build script + global: + - BUILD_GOARCH=amd64 + matrix: + - BUILD_GOOS=linux + - BUILD_GOOS=darwin + - BUILD_GOOS=windows +script: + - gvm cross $BUILD_GOOS $BUILD_GOARCH + - GOARCH=$BUILD_GOARCH GOOS=$BUILD_GOOS go build + + # only test on linux + - if [ $BUILD_GOOS == "linux" ]; then GOARCH=$BUILD_GOARCH GOOS=$BUILD_GOOS go test -bench=.; fi diff --git a/Godeps/_workspace/src/github.com/miekg/dns/AUTHORS b/Godeps/_workspace/src/github.com/miekg/dns/AUTHORS new file mode 100644 index 0000000000..1965683525 --- /dev/null +++ b/Godeps/_workspace/src/github.com/miekg/dns/AUTHORS @@ -0,0 +1 @@ +Miek Gieben diff --git a/Godeps/_workspace/src/github.com/miekg/dns/CONTRIBUTORS b/Godeps/_workspace/src/github.com/miekg/dns/CONTRIBUTORS new file mode 100644 index 0000000000..f77e8a895f --- /dev/null +++ 
b/Godeps/_workspace/src/github.com/miekg/dns/CONTRIBUTORS @@ -0,0 +1,9 @@ +Alex A. Skinner +Andrew Tunnell-Jones +Ask Bjørn Hansen +Dave Cheney +Dusty Wilson +Marek Majkowski +Peter van Dijk +Omri Bahumi +Alex Sergeyev diff --git a/Godeps/_workspace/src/github.com/miekg/dns/COPYRIGHT b/Godeps/_workspace/src/github.com/miekg/dns/COPYRIGHT new file mode 100644 index 0000000000..35702b10e8 --- /dev/null +++ b/Godeps/_workspace/src/github.com/miekg/dns/COPYRIGHT @@ -0,0 +1,9 @@ +Copyright 2009 The Go Authors. All rights reserved. Use of this source code +is governed by a BSD-style license that can be found in the LICENSE file. +Extensions of the original work are copyright (c) 2011 Miek Gieben + +Copyright 2011 Miek Gieben. All rights reserved. Use of this source code is +governed by a BSD-style license that can be found in the LICENSE file. + +Copyright 2014 CloudFlare. All rights reserved. Use of this source code is +governed by a BSD-style license that can be found in the LICENSE file. diff --git a/Godeps/_workspace/src/github.com/miekg/dns/LICENSE b/Godeps/_workspace/src/github.com/miekg/dns/LICENSE new file mode 100644 index 0000000000..5763fa7fe5 --- /dev/null +++ b/Godeps/_workspace/src/github.com/miekg/dns/LICENSE @@ -0,0 +1,32 @@ +Extensions of the original work are copyright (c) 2011 Miek Gieben + +As this is fork of the official Go code the same license applies: + +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + diff --git a/Godeps/_workspace/src/github.com/miekg/dns/README.md b/Godeps/_workspace/src/github.com/miekg/dns/README.md new file mode 100644 index 0000000000..20fe64ba26 --- /dev/null +++ b/Godeps/_workspace/src/github.com/miekg/dns/README.md @@ -0,0 +1,140 @@ +[![Build Status](https://travis-ci.org/miekg/dns.svg?branch=master)](https://travis-ci.org/miekg/dns) + +# Alternative (more granular) approach to a DNS library + +> Less is more. + +Complete and usable DNS library. All widely used Resource Records are +supported, including the DNSSEC types. It follows a lean and mean philosophy. 
+If there is something you should know as a DNS programmer, there isn't a
+convenience function for it. Both server-side and client-side programming are
+supported, i.e. you can build servers and resolvers with it.
+
+If you like this, you may also be interested in:
+
+* https://github.com/miekg/unbound -- Go wrapper for the Unbound resolver.
+
+# Goals
+
+* KISS;
+* Fast;
+* Small API, if it's easy to code in Go, don't make a function for it.
+
+# Users
+
+A not-so-up-to-date-list-that-may-be-actually-current:
+
+* https://github.com/abh/geodns
+* http://www.statdns.com/
+* http://www.dnsinspect.com/
+* https://github.com/chuangbo/jianbing-dictionary-dns
+* http://www.dns-lg.com/
+* https://github.com/fcambus/rrda
+* https://github.com/kenshinx/godns
+* https://github.com/skynetservices/skydns
+* https://github.com/DevelopersPL/godnsagent
+* https://github.com/duedil-ltd/discodns
+
+Send a pull request if you want to be listed here.
+
+# Features
+
+* UDP/TCP queries, IPv4 and IPv6;
+* RFC 1035 zone file parsing ($INCLUDE, $ORIGIN, $TTL and $GENERATE (for all record types) are supported);
+* Fast:
+  * Reply speed of around 80K qps (faster hardware results in more qps);
+  * Parsing RRs at ~100K RR/s, that's 5M records in about 50 seconds;
+* Server-side programming (mimicking the net/http package);
+* Client-side programming;
+* DNSSEC: signing, validating and key generation for DSA, RSA and ECDSA;
+* EDNS0, NSID;
+* AXFR/IXFR;
+* TSIG;
+* DNS name compression;
+* Depends only on the standard library.
+
+Have fun!
+
+Miek Gieben - 2010-2012 - 
+
+# Building
+
+Building is done with the `go` tool. If you have set up your GOPATH
+correctly, the following should work:
+
+    go get github.com/miekg/dns
+    go build github.com/miekg/dns
+
+## Examples
+
+A short "how to use the API" is at the beginning of dns.go (it will also show
+when you run `godoc github.com/miekg/dns`).
+
+Example programs can be found in the `github.com/miekg/exdns` repository.
+
+## Supported RFCs
+
+*all of them*
+
+* 103{4,5} - DNS standard
+* 1348 - NSAP record
+* 1876 - LOC record
+* 1982 - Serial Arithmetic
+* 1995 - IXFR
+* 1996 - DNS notify
+* 2065 - DNSSEC (updated in later RFCs)
+* 2136 - DNS Update (dynamic updates)
+* 2181 - RRset definition - there is no RRset type though, just []RR
+* 2537 - RSAMD5 DNS keys
+* 2671 - EDNS record
+* 2782 - SRV record
+* 2845 - TSIG record
+* 2915 - NAPTR record
+* 2929 - DNS IANA Considerations
+* 3110 - RSASHA1 DNS keys
+* 3225 - DO bit (DNSSEC OK)
+* 340{1,2,3} - NAPTR record
+* 3445 - Limiting the scope of (DNS)KEY
+* 3597 - Unknown RRs
+* 403{3,4,5} - DNSSEC + validation functions
+* 4255 - SSHFP record
+* 4343 - Case insensitivity
+* 4408 - SPF record
+* 4509 - SHA256 Hash in DS
+* 4592 - Wildcards in the DNS
+* 4635 - HMAC SHA TSIG
+* 4701 - DHCID
+* 4892 - id.server
+* 5001 - NSID
+* 5155 - NSEC3 record
+* 5205 - HIP record
+* 5702 - SHA2 in the DNS
+* 5936 - AXFR
+* 5966 - TCP implementation recommendations
+* 6605 - ECDSA
+* 6742 - ILNP DNS
+* 6891 - EDNS0 update
+* 6895 - DNS IANA considerations
+* 6975 - Algorithm Understanding in DNSSEC
+* 7043 - EUI48/EUI64 records
+* 7314 - DNS (EDNS) EXPIRE Option
+* xxxx - URI record (draft)
+* xxxx - EDNS0 DNS Update Lease (draft)
+
+## Loosely based upon
+
+* `ldns`
+* `NSD`
+* `Net::DNS`
+* `GRONG`
+
+## TODO
+
+* privatekey.Precompute() when signing?
+* Last remaining RRs: APL, ATMA, A6 and NXT; +* Missing in parsing: ISDN, UNSPEC, ATMA; +* CAA parsing is broken; +* NSEC(3) cover/match/closest enclose; +* Replies with TC bit are not parsed to the end; +* SIG(0); +* Create IsMsg to validate a message before fully parsing it. diff --git a/Godeps/_workspace/src/github.com/miekg/dns/client.go b/Godeps/_workspace/src/github.com/miekg/dns/client.go new file mode 100644 index 0000000000..ee8e22333b --- /dev/null +++ b/Godeps/_workspace/src/github.com/miekg/dns/client.go @@ -0,0 +1,319 @@ +package dns + +// A client implementation. + +import ( + "bytes" + "io" + "net" + "time" +) + +const dnsTimeout time.Duration = 2 * 1e9 +const tcpIdleTimeout time.Duration = 8 * time.Second + +// A Conn represents a connection to a DNS server. +type Conn struct { + net.Conn // a net.Conn holding the connection + UDPSize uint16 // minimum receive buffer for UDP messages + TsigSecret map[string]string // secret(s) for Tsig map[], zonename must be fully qualified + rtt time.Duration + t time.Time + tsigRequestMAC string +} + +// A Client defines parameters for a DNS client. +type Client struct { + Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one (default is "" for UDP) + UDPSize uint16 // minimum receive buffer for UDP messages + DialTimeout time.Duration // net.DialTimeout (ns), defaults to 2 * 1e9 + ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections (ns), defaults to 2 * 1e9 + WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections (ns), defaults to 2 * 1e9 + TsigSecret map[string]string // secret(s) for Tsig map[], zonename must be fully qualified + SingleInflight bool // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass + group singleflight +} + +// Exchange performs a synchronous UDP query. It sends the message m to the address +// contained in a and waits for an reply. Exchange does not retry a failed query, nor +// will it fall back to TCP in case of truncation. +// If you need to send a DNS message on an already existing connection, you can use the +// following: +// +// co := &dns.Conn{Conn: c} // c is your net.Conn +// co.WriteMsg(m) +// in, err := co.ReadMsg() +// co.Close() +// +func Exchange(m *Msg, a string) (r *Msg, err error) { + var co *Conn + co, err = DialTimeout("udp", a, dnsTimeout) + if err != nil { + return nil, err + } + + defer co.Close() + co.SetReadDeadline(time.Now().Add(dnsTimeout)) + co.SetWriteDeadline(time.Now().Add(dnsTimeout)) + if err = co.WriteMsg(m); err != nil { + return nil, err + } + r, err = co.ReadMsg() + return r, err +} + +// ExchangeConn performs a synchronous query. It sends the message m via the connection +// c and waits for a reply. The connection c is not closed by ExchangeConn. +// This function is going away, but can easily be mimicked: +// +// co := &dns.Conn{Conn: c} // c is your net.Conn +// co.WriteMsg(m) +// in, _ := co.ReadMsg() +// co.Close() +// +func ExchangeConn(c net.Conn, m *Msg) (r *Msg, err error) { + println("dns: this function is deprecated") + co := new(Conn) + co.Conn = c + if err = co.WriteMsg(m); err != nil { + return nil, err + } + r, err = co.ReadMsg() + return r, err +} + +// Exchange performs an synchronous query. It sends the message m to the address +// contained in a and waits for an reply. 
Basic use pattern with a *dns.Client: +// +// c := new(dns.Client) +// in, rtt, err := c.Exchange(message, "127.0.0.1:53") +// +// Exchange does not retry a failed query, nor will it fall back to TCP in +// case of truncation. +func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) { + if !c.SingleInflight { + return c.exchange(m, a) + } + // This adds a bunch of garbage, TODO(miek). + t := "nop" + if t1, ok := TypeToString[m.Question[0].Qtype]; ok { + t = t1 + } + cl := "nop" + if cl1, ok := ClassToString[m.Question[0].Qclass]; ok { + cl = cl1 + } + r, rtt, err, shared := c.group.Do(m.Question[0].Name+t+cl, func() (*Msg, time.Duration, error) { + return c.exchange(m, a) + }) + if err != nil { + return r, rtt, err + } + if shared { + return r.Copy(), rtt, nil + } + return r, rtt, nil +} + +func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) { + timeout := dnsTimeout + var co *Conn + if c.DialTimeout != 0 { + timeout = c.DialTimeout + } + if c.Net == "" { + co, err = DialTimeout("udp", a, timeout) + } else { + co, err = DialTimeout(c.Net, a, timeout) + } + if err != nil { + return nil, 0, err + } + timeout = dnsTimeout + if c.ReadTimeout != 0 { + timeout = c.ReadTimeout + } + co.SetReadDeadline(time.Now().Add(timeout)) + timeout = dnsTimeout + if c.WriteTimeout != 0 { + timeout = c.WriteTimeout + } + co.SetWriteDeadline(time.Now().Add(timeout)) + defer co.Close() + opt := m.IsEdns0() + // If EDNS0 is used use that for size. + if opt != nil && opt.UDPSize() >= MinMsgSize { + co.UDPSize = opt.UDPSize() + } + // Otherwise use the client's configured UDP size. + if opt == nil && c.UDPSize >= MinMsgSize { + co.UDPSize = c.UDPSize + } + co.TsigSecret = c.TsigSecret + if err = co.WriteMsg(m); err != nil { + return nil, 0, err + } + r, err = co.ReadMsg() + return r, co.rtt, err +} + +// ReadMsg reads a message from the connection co. +// If the received message contains a TSIG record the transaction +// signature is verified. +func (co *Conn) ReadMsg() (*Msg, error) { + var p []byte + m := new(Msg) + if _, ok := co.Conn.(*net.TCPConn); ok { + p = make([]byte, MaxMsgSize) + } else { + if co.UDPSize >= 512 { + p = make([]byte, co.UDPSize) + } else { + p = make([]byte, MinMsgSize) + } + } + n, err := co.Read(p) + if err != nil && n == 0 { + return nil, err + } + p = p[:n] + if err := m.Unpack(p); err != nil { + return nil, err + } + co.rtt = time.Since(co.t) + if t := m.IsTsig(); t != nil { + if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { + return m, ErrSecret + } + // Need to work on the original message p, as that was used to calculate the tsig. + err = TsigVerify(p, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) + } + return m, err +} + +// Read implements the net.Conn read method. 
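+// On a TCP connection a DNS message is preceded by a two-octet length field
+// (RFC 1035, section 4.2.2), so Read first reads those two bytes and then
+// loops until the announced number of bytes is in p. On UDP a single Read
+// returns one whole datagram.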
+func (co *Conn) Read(p []byte) (n int, err error) {
+	if co.Conn == nil {
+		return 0, ErrConnEmpty
+	}
+	if len(p) < 2 {
+		return 0, io.ErrShortBuffer
+	}
+	if t, ok := co.Conn.(*net.TCPConn); ok {
+		n, err = t.Read(p[0:2])
+		if err != nil || n != 2 {
+			return n, err
+		}
+		l, _ := unpackUint16(p[0:2], 0)
+		if l == 0 {
+			return 0, ErrShortRead
+		}
+		if int(l) > len(p) {
+			return int(l), io.ErrShortBuffer
+		}
+		n, err = t.Read(p[:l])
+		if err != nil {
+			return n, err
+		}
+		i := n
+		for i < int(l) {
+			j, err := t.Read(p[i:int(l)])
+			if err != nil {
+				return i, err
+			}
+			i += j
+		}
+		n = i
+		return n, err
+	}
+	// UDP connection
+	return co.Conn.Read(p)
+}
+
+// WriteMsg sends a message through the connection co.
+// If the message m contains a TSIG record the transaction
+// signature is calculated.
+func (co *Conn) WriteMsg(m *Msg) (err error) {
+	var out []byte
+	if t := m.IsTsig(); t != nil {
+		mac := ""
+		if _, ok := co.TsigSecret[t.Hdr.Name]; !ok {
+			return ErrSecret
+		}
+		out, mac, err = TsigGenerate(m, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false)
+		// Set for the next read, although only used in zone transfers.
+		co.tsigRequestMAC = mac
+	} else {
+		out, err = m.Pack()
+	}
+	if err != nil {
+		return err
+	}
+	co.t = time.Now()
+	if _, err = co.Write(out); err != nil {
+		return err
+	}
+	return nil
+}
+
+// Write implements the net.Conn Write method.
+func (co *Conn) Write(p []byte) (n int, err error) {
+	if t, ok := co.Conn.(*net.TCPConn); ok {
+		lp := len(p)
+		if lp < 2 {
+			return 0, io.ErrShortBuffer
+		}
+		if lp > MaxMsgSize {
+			return 0, &Error{err: "message too large"}
+		}
+		l := make([]byte, 2, lp+2)
+		l[0], l[1] = packUint16(uint16(lp))
+		p = append(l, p...)
+		n, err := io.Copy(t, bytes.NewReader(p))
+		return int(n), err
+	}
+	n, err = co.Conn.(*net.UDPConn).Write(p)
+	return n, err
+}
+
+// Dial connects to the address on the named network.
+func Dial(network, address string) (conn *Conn, err error) {
+	conn = new(Conn)
+	conn.Conn, err = net.Dial(network, address)
+	if err != nil {
+		return nil, err
+	}
+	return conn, nil
+}
+
+// DialTimeout acts like Dial but takes a timeout.
+func DialTimeout(network, address string, timeout time.Duration) (conn *Conn, err error) {
+	conn = new(Conn)
+	conn.Conn, err = net.DialTimeout(network, address, timeout)
+	if err != nil {
+		return nil, err
+	}
+	return conn, nil
+}
+
+// Close implements the net.Conn Close method.
+func (co *Conn) Close() error { return co.Conn.Close() }
+
+// LocalAddr implements the net.Conn LocalAddr method.
+func (co *Conn) LocalAddr() net.Addr { return co.Conn.LocalAddr() }
+
+// RemoteAddr implements the net.Conn RemoteAddr method.
+func (co *Conn) RemoteAddr() net.Addr { return co.Conn.RemoteAddr() }
+
+// SetDeadline implements the net.Conn SetDeadline method.
+func (co *Conn) SetDeadline(t time.Time) error { return co.Conn.SetDeadline(t) }
+
+// SetReadDeadline implements the net.Conn SetReadDeadline method.
+func (co *Conn) SetReadDeadline(t time.Time) error { return co.Conn.SetReadDeadline(t) }
+
+// SetWriteDeadline implements the net.Conn SetWriteDeadline method.
+func (co *Conn) SetWriteDeadline(t time.Time) error { return co.Conn.SetWriteDeadline(t) } diff --git a/Godeps/_workspace/src/github.com/miekg/dns/client_test.go b/Godeps/_workspace/src/github.com/miekg/dns/client_test.go new file mode 100644 index 0000000000..9691a9cd2a --- /dev/null +++ b/Godeps/_workspace/src/github.com/miekg/dns/client_test.go @@ -0,0 +1,195 @@ +package dns + +import ( + "testing" + "time" +) + +func TestClientSync(t *testing.T) { + HandleFunc("miek.nl.", HelloServer) + defer HandleRemove("miek.nl.") + + s, addrstr, err := RunLocalUDPServer("127.0.0.1:0") + if err != nil { + t.Fatalf("Unable to run test server: %s", err) + } + defer s.Shutdown() + + m := new(Msg) + m.SetQuestion("miek.nl.", TypeSOA) + + c := new(Client) + r, _, e := c.Exchange(m, addrstr) + if e != nil { + t.Logf("failed to exchange: %s", e.Error()) + t.Fail() + } + if r != nil && r.Rcode != RcodeSuccess { + t.Log("failed to get an valid answer") + t.Fail() + t.Logf("%v\n", r) + } + // And now with plain Exchange(). + r, e = Exchange(m, addrstr) + if e != nil { + t.Logf("failed to exchange: %s", e.Error()) + t.Fail() + } + if r != nil && r.Rcode != RcodeSuccess { + t.Log("failed to get an valid answer") + t.Fail() + t.Logf("%v\n", r) + } +} + +func TestClientEDNS0(t *testing.T) { + HandleFunc("miek.nl.", HelloServer) + defer HandleRemove("miek.nl.") + + s, addrstr, err := RunLocalUDPServer("127.0.0.1:0") + if err != nil { + t.Fatalf("Unable to run test server: %s", err) + } + defer s.Shutdown() + + m := new(Msg) + m.SetQuestion("miek.nl.", TypeDNSKEY) + + m.SetEdns0(2048, true) + + c := new(Client) + r, _, e := c.Exchange(m, addrstr) + if e != nil { + t.Logf("failed to exchange: %s", e.Error()) + t.Fail() + } + + if r != nil && r.Rcode != RcodeSuccess { + t.Log("failed to get an valid answer") + t.Fail() + t.Logf("%v\n", r) + } +} + +func TestSingleSingleInflight(t *testing.T) { + HandleFunc("miek.nl.", HelloServer) + defer HandleRemove("miek.nl.") + + s, addrstr, err := RunLocalUDPServer("127.0.0.1:0") + if err != nil { + t.Fatalf("Unable to run test server: %s", err) + } + defer s.Shutdown() + + m := new(Msg) + m.SetQuestion("miek.nl.", TypeDNSKEY) + + c := new(Client) + c.SingleInflight = true + nr := 10 + ch := make(chan time.Duration) + for i := 0; i < nr; i++ { + go func() { + _, rtt, _ := c.Exchange(m, addrstr) + ch <- rtt + }() + } + i := 0 + var first time.Duration + // With inflight *all* rtt are identical, and by doing actual lookups + // the changes that this is a coincidence is small. 
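+	// (All ten goroutines share a single in-flight exchange through the
+	// client's singleflight group, so every caller observes that one
+	// query's rtt.)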
+Loop: + for { + select { + case rtt := <-ch: + if i == 0 { + first = rtt + } else { + if first != rtt { + t.Log("all rtts should be equal") + t.Fail() + } + } + i++ + if i == 10 { + break Loop + } + } + } +} + +/* +func TestClientTsigAXFR(t *testing.T) { + m := new(Msg) + m.SetAxfr("example.nl.") + m.SetTsig("axfr.", HmacMD5, 300, time.Now().Unix()) + + tr := new(Transfer) + tr.TsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="} + + if a, err := tr.In(m, "176.58.119.54:53"); err != nil { + t.Log("failed to setup axfr: " + err.Error()) + t.Fatal() + } else { + for ex := range a { + if ex.Error != nil { + t.Logf("error %s\n", ex.Error.Error()) + t.Fail() + break + } + for _, rr := range ex.RR { + t.Logf("%s\n", rr.String()) + } + } + } +} + +func TestClientAXFRMultipleEnvelopes(t *testing.T) { + m := new(Msg) + m.SetAxfr("nlnetlabs.nl.") + + tr := new(Transfer) + if a, err := tr.In(m, "213.154.224.1:53"); err != nil { + t.Log("Failed to setup axfr" + err.Error()) + t.Fail() + return + } else { + for ex := range a { + if ex.Error != nil { + t.Logf("Error %s\n", ex.Error.Error()) + t.Fail() + break + } + } + } +} +*/ + +// ExapleUpdateLeaseTSIG shows how to update a lease signed with TSIG. +func ExampleUpdateLeaseTSIG(t *testing.T) { + m := new(Msg) + m.SetUpdate("t.local.ip6.io.") + rr, _ := NewRR("t.local.ip6.io. 30 A 127.0.0.1") + rrs := make([]RR, 1) + rrs[0] = rr + m.Insert(rrs) + + lease_rr := new(OPT) + lease_rr.Hdr.Name = "." + lease_rr.Hdr.Rrtype = TypeOPT + e := new(EDNS0_UL) + e.Code = EDNS0UL + e.Lease = 120 + lease_rr.Option = append(lease_rr.Option, e) + m.Extra = append(m.Extra, lease_rr) + + c := new(Client) + m.SetTsig("polvi.", HmacMD5, 300, time.Now().Unix()) + c.TsigSecret = map[string]string{"polvi.": "pRZgBrBvI4NAHZYhxmhs/Q=="} + + _, _, err := c.Exchange(m, "127.0.0.1:53") + if err != nil { + t.Log(err.Error()) + t.Fail() + } +} diff --git a/Godeps/_workspace/src/github.com/miekg/dns/clientconfig.go b/Godeps/_workspace/src/github.com/miekg/dns/clientconfig.go new file mode 100644 index 0000000000..87cf896186 --- /dev/null +++ b/Godeps/_workspace/src/github.com/miekg/dns/clientconfig.go @@ -0,0 +1,94 @@ +package dns + +import ( + "bufio" + "os" + "strconv" + "strings" +) + +// Wraps the contents of the /etc/resolv.conf. +type ClientConfig struct { + Servers []string // servers to use + Search []string // suffixes to append to local name + Port string // what port to use + Ndots int // number of dots in name to trigger absolute lookup + Timeout int // seconds before giving up on packet + Attempts int // lost packets before giving up on server, not used in the package dns +} + +// ClientConfigFromFile parses a resolv.conf(5) like file and returns +// a *ClientConfig. +func ClientConfigFromFile(resolvconf string) (*ClientConfig, error) { + file, err := os.Open(resolvconf) + if err != nil { + return nil, err + } + defer file.Close() + c := new(ClientConfig) + b := bufio.NewReader(file) + c.Servers = make([]string, 0) + c.Search = make([]string, 0) + c.Port = "53" + c.Ndots = 1 + c.Timeout = 5 + c.Attempts = 2 + for line, ok := b.ReadString('\n'); ok == nil; line, ok = b.ReadString('\n') { + f := strings.Fields(line) + if len(f) < 1 { + continue + } + switch f[0] { + case "nameserver": // add one name server + if len(f) > 1 { + // One more check: make sure server name is + // just an IP address. Otherwise we need DNS + // to look it up. 
+				name := f[1]
+				c.Servers = append(c.Servers, name)
+			}
+
+		case "domain": // set search path to just this domain
+			if len(f) > 1 {
+				c.Search = make([]string, 1)
+				c.Search[0] = f[1]
+			} else {
+				c.Search = make([]string, 0)
+			}
+
+		case "search": // set search path to given servers
+			c.Search = make([]string, len(f)-1)
+			for i := 0; i < len(c.Search); i++ {
+				c.Search[i] = f[i+1]
+			}
+
+		case "options": // magic options
+			for i := 1; i < len(f); i++ {
+				s := f[i]
+				switch {
+				case len(s) >= 6 && s[:6] == "ndots:":
+					n, _ := strconv.Atoi(s[6:])
+					if n < 1 {
+						n = 1
+					}
+					c.Ndots = n
+				case len(s) >= 8 && s[:8] == "timeout:":
+					n, _ := strconv.Atoi(s[8:])
+					if n < 1 {
+						n = 1
+					}
+					c.Timeout = n
+				case len(s) >= 9 && s[:9] == "attempts:":
+					n, _ := strconv.Atoi(s[9:])
+					if n < 1 {
+						n = 1
+					}
+					c.Attempts = n
+				case s == "rotate":
+					/* not implemented */
+				}
+			}
+		}
+	}
+	return c, nil
+}
diff --git a/Godeps/_workspace/src/github.com/miekg/dns/defaults.go b/Godeps/_workspace/src/github.com/miekg/dns/defaults.go
new file mode 100644
index 0000000000..f483d4a342
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/miekg/dns/defaults.go
@@ -0,0 +1,242 @@
+package dns
+
+import (
+	"errors"
+	"net"
+	"strconv"
+)
+
+const hexDigit = "0123456789abcdef"
+
+// Everything is assumed in ClassINET.
+
+// SetReply creates a reply message from a request message.
+func (dns *Msg) SetReply(request *Msg) *Msg {
+	dns.Id = request.Id
+	dns.RecursionDesired = request.RecursionDesired // Copy rd bit
+	dns.Response = true
+	dns.Opcode = OpcodeQuery
+	dns.Rcode = RcodeSuccess
+	if len(request.Question) > 0 {
+		dns.Question = make([]Question, 1)
+		dns.Question[0] = request.Question[0]
+	}
+	return dns
+}
+
+// SetQuestion creates a question message.
+func (dns *Msg) SetQuestion(z string, t uint16) *Msg {
+	dns.Id = Id()
+	dns.RecursionDesired = true
+	dns.Question = make([]Question, 1)
+	dns.Question[0] = Question{z, t, ClassINET}
+	return dns
+}
+
+// SetNotify creates a notify message.
+func (dns *Msg) SetNotify(z string) *Msg {
+	dns.Opcode = OpcodeNotify
+	dns.Authoritative = true
+	dns.Id = Id()
+	dns.Question = make([]Question, 1)
+	dns.Question[0] = Question{z, TypeSOA, ClassINET}
+	return dns
+}
+
+// SetRcode creates an error message suitable for the request.
+func (dns *Msg) SetRcode(request *Msg, rcode int) *Msg {
+	dns.SetReply(request)
+	dns.Rcode = rcode
+	return dns
+}
+
+// SetRcodeFormatError creates a message with FormError set.
+func (dns *Msg) SetRcodeFormatError(request *Msg) *Msg {
+	dns.Rcode = RcodeFormatError
+	dns.Opcode = OpcodeQuery
+	dns.Response = true
+	dns.Authoritative = false
+	dns.Id = request.Id
+	return dns
+}
+
+// SetUpdate makes the message a dynamic update message. It
+// sets the ZONE section to: z, TypeSOA, ClassINET.
+func (dns *Msg) SetUpdate(z string) *Msg {
+	dns.Id = Id()
+	dns.Response = false
+	dns.Opcode = OpcodeUpdate
+	dns.Compress = false // BIND9 cannot handle compression
+	dns.Question = make([]Question, 1)
+	dns.Question[0] = Question{z, TypeSOA, ClassINET}
+	return dns
+}
+
+// SetIxfr creates a message for requesting an IXFR.
+func (dns *Msg) SetIxfr(z string, serial uint32) *Msg {
+	dns.Id = Id()
+	dns.Question = make([]Question, 1)
+	dns.Ns = make([]RR, 1)
+	s := new(SOA)
+	s.Hdr = RR_Header{z, TypeSOA, ClassINET, defaultTtl, 0}
+	s.Serial = serial
+	dns.Question[0] = Question{z, TypeIXFR, ClassINET}
+	dns.Ns[0] = s
+	return dns
+}
+
+// SetAxfr creates a message for requesting an AXFR.
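+// A typical pairing with Transfer (a sketch along the lines of the
+// commented-out AXFR test in client_test.go; zone and address are examples):
+//
+//	m := new(Msg)
+//	m.SetAxfr("example.org.")
+//	tr := new(Transfer)
+//	env, err := tr.In(m, "ns1.example.org:53")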
+func (dns *Msg) SetAxfr(z string) *Msg { + dns.Id = Id() + dns.Question = make([]Question, 1) + dns.Question[0] = Question{z, TypeAXFR, ClassINET} + return dns +} + +// SetTsig appends a TSIG RR to the message. +// This is only a skeleton TSIG RR that is added as the last RR in the +// additional section. The Tsig is calculated when the message is being send. +func (dns *Msg) SetTsig(z, algo string, fudge, timesigned int64) *Msg { + t := new(TSIG) + t.Hdr = RR_Header{z, TypeTSIG, ClassANY, 0, 0} + t.Algorithm = algo + t.Fudge = 300 + t.TimeSigned = uint64(timesigned) + t.OrigId = dns.Id + dns.Extra = append(dns.Extra, t) + return dns +} + +// SetEdns0 appends a EDNS0 OPT RR to the message. +// TSIG should always the last RR in a message. +func (dns *Msg) SetEdns0(udpsize uint16, do bool) *Msg { + e := new(OPT) + e.Hdr.Name = "." + e.Hdr.Rrtype = TypeOPT + e.SetUDPSize(udpsize) + if do { + e.SetDo() + } + dns.Extra = append(dns.Extra, e) + return dns +} + +// IsTsig checks if the message has a TSIG record as the last record +// in the additional section. It returns the TSIG record found or nil. +func (dns *Msg) IsTsig() *TSIG { + if len(dns.Extra) > 0 { + if dns.Extra[len(dns.Extra)-1].Header().Rrtype == TypeTSIG { + return dns.Extra[len(dns.Extra)-1].(*TSIG) + } + } + return nil +} + +// IsEdns0 checks if the message has a EDNS0 (OPT) record, any EDNS0 +// record in the additional section will do. It returns the OPT record +// found or nil. +func (dns *Msg) IsEdns0() *OPT { + for _, r := range dns.Extra { + if r.Header().Rrtype == TypeOPT { + return r.(*OPT) + } + } + return nil +} + +// IsDomainName checks if s is a valid domainname, it returns +// the number of labels and true, when a domain name is valid. +// Note that non fully qualified domain name is considered valid, in this case the +// last label is counted in the number of labels. +// When false is returned the number of labels is not defined. +func IsDomainName(s string) (labels int, ok bool) { + _, labels, err := packDomainName(s, nil, 0, nil, false) + return labels, err == nil +} + +// IsSubDomain checks if child is indeed a child of the parent. Both child and +// parent are *not* downcased before doing the comparison. +func IsSubDomain(parent, child string) bool { + // Entire child is contained in parent + return CompareDomainName(parent, child) == CountLabel(parent) +} + +// IsMsg sanity checks buf and returns an error if it isn't a valid DNS packet. +// The checking is performed on the binary payload. +func IsMsg(buf []byte) error { + // Header + if len(buf) < 12 { + return errors.New("dns: bad message header") + } + // Header: Opcode + + return nil +} + +// IsFqdn checks if a domain name is fully qualified. +func IsFqdn(s string) bool { + l := len(s) + if l == 0 { + return false + } + return s[l-1] == '.' +} + +// Fqdns return the fully qualified domain name from s. +// If s is already fully qualified, it behaves as the identity function. +func Fqdn(s string) string { + if IsFqdn(s) { + return s + } + return s + "." +} + +// Copied from the official Go code. + +// ReverseAddr returns the in-addr.arpa. or ip6.arpa. hostname of the IP +// address suitable for reverse DNS (PTR) record lookups or an error if it fails +// to parse the IP address. +func ReverseAddr(addr string) (arpa string, err error) { + ip := net.ParseIP(addr) + if ip == nil { + return "", &Error{err: "unrecognized address: " + addr} + } + if ip.To4() != nil { + return strconv.Itoa(int(ip[15])) + "." + strconv.Itoa(int(ip[14])) + "." 
+			strconv.Itoa(int(ip[13])) + "." +
+			strconv.Itoa(int(ip[12])) + ".in-addr.arpa.", nil
+	}
+	// Must be IPv6
+	buf := make([]byte, 0, len(ip)*4+len("ip6.arpa."))
+	// Add it, in reverse, to the buffer
+	for i := len(ip) - 1; i >= 0; i-- {
+		v := ip[i]
+		buf = append(buf, hexDigit[v&0xF])
+		buf = append(buf, '.')
+		buf = append(buf, hexDigit[v>>4])
+		buf = append(buf, '.')
+	}
+	// Append "ip6.arpa." and return (buf already has the final .)
+	buf = append(buf, "ip6.arpa."...)
+	return string(buf), nil
+}
+
+// String returns the string representation for the type t.
+func (t Type) String() string {
+	if t1, ok := TypeToString[uint16(t)]; ok {
+		return t1
+	}
+	return "TYPE" + strconv.Itoa(int(t))
+}
+
+// String returns the string representation for the class c.
+func (c Class) String() string {
+	if c1, ok := ClassToString[uint16(c)]; ok {
+		return c1
+	}
+	return "CLASS" + strconv.Itoa(int(c))
+}
+
+// String returns the string representation for the name n.
+func (n Name) String() string {
+	return sprintName(string(n))
+}
diff --git a/Godeps/_workspace/src/github.com/miekg/dns/dns.go b/Godeps/_workspace/src/github.com/miekg/dns/dns.go
new file mode 100644
index 0000000000..7540c0d5bc
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/miekg/dns/dns.go
@@ -0,0 +1,193 @@
+// Package dns implements a full featured interface to the Domain Name System.
+// Server- and client-side programming is supported.
+// The package allows complete control over what is sent out to the DNS. The package
+// API follows the less-is-more principle, by presenting a small, clean interface.
+//
+// The package dns supports (asynchronous) querying/replying, incoming/outgoing zone transfers,
+// TSIG, EDNS0, dynamic updates, notifies and DNSSEC validation/signing.
+// Note that domain names MUST be fully qualified before sending them; unqualified
+// names in a message will result in a packing failure.
+//
+// Resource records are native types. They are not stored in wire format.
+// Basic usage pattern for creating a new resource record:
+//
+//	r := new(dns.MX)
+//	r.Hdr = dns.RR_Header{Name: "miek.nl.", Rrtype: dns.TypeMX, Class: dns.ClassINET, Ttl: 3600}
+//	r.Preference = 10
+//	r.Mx = "mx.miek.nl."
+//
+// Or directly from a string:
+//
+//	mx, err := dns.NewRR("miek.nl. 3600 IN MX 10 mx.miek.nl.")
+//
+// Or when the default TTL (3600) and class (IN) suit you:
+//
+//	mx, err := dns.NewRR("miek.nl. MX 10 mx.miek.nl.")
+//
+// Or even:
+//
+//	mx, err := dns.NewRR("$ORIGIN nl.\nmiek 1H IN MX 10 mx.miek")
+//
+// In the DNS, messages are exchanged; these messages contain resource
+// records (sets). Basic use pattern for creating a message:
+//
+//	m := new(dns.Msg)
+//	m.SetQuestion("miek.nl.", dns.TypeMX)
+//
+// Or when not certain if the domain name is fully qualified:
+//
+//	m.SetQuestion(dns.Fqdn("miek.nl"), dns.TypeMX)
+//
+// The message m is now a message with the question section set to ask
+// for the MX records of the miek.nl. zone.
+//
+// The following is slightly more verbose, but more flexible:
+//
+//	m1 := new(dns.Msg)
+//	m1.Id = dns.Id()
+//	m1.RecursionDesired = true
+//	m1.Question = make([]dns.Question, 1)
+//	m1.Question[0] = dns.Question{"miek.nl.", dns.TypeMX, dns.ClassINET}
+//
+// After creating a message it can be sent.
+// Basic use pattern for synchronously querying the DNS at a
+// server configured on 127.0.0.1 and port 53:
+//
+//	c := new(dns.Client)
+//	in, rtt, err := c.Exchange(m1, "127.0.0.1:53")
+//
+// Suppressing multiple outstanding queries (with the same question, type and
+// class) is as easy as setting:
+//
+//	c.SingleInflight = true
+//
+// If these "advanced" features are not needed, a simple UDP query can be sent with:
+//
+//	in, err := dns.Exchange(m1, "127.0.0.1:53")
+//
+// When this function returns you will get a DNS message. A DNS message
+// consists of four sections.
+// The question section: in.Question, the answer section: in.Answer,
+// the authority section: in.Ns and the additional section: in.Extra.
+//
+// Each of these sections (except the Question section) contains a []RR. Basic
+// use pattern for accessing the rdata of a TXT RR as the first RR in
+// the Answer section:
+//
+//	if t, ok := in.Answer[0].(*dns.TXT); ok {
+//		// do something with t.Txt
+//	}
+//
+// Domain Name and TXT Character String Representations
+//
+// Both domain names and TXT character strings are converted to presentation
+// form both when unpacked and when converted to strings.
+//
+// For TXT character strings, tabs, carriage returns and line feeds will be
+// converted to \t, \r and \n respectively. Backslashes and quotation marks
+// will be escaped. Bytes below 32 and above 127 will be converted to \DDD
+// form.
+//
+// For domain names, in addition to the above rules, brackets, periods,
+// spaces, semicolons and the at symbol are escaped.
+package dns
+
+import (
+	"strconv"
+)
+
+const (
+	year68         = 1 << 31 // For RFC1982 (Serial Arithmetic) calculations in 32 bits.
+	DefaultMsgSize = 4096    // Standard default for larger than 512 bytes.
+	MinMsgSize     = 512     // Minimal size of a DNS packet.
+	MaxMsgSize     = 65536   // Largest possible DNS packet.
+	defaultTtl     = 3600    // Default TTL.
+)
+
+// Error represents a DNS error.
+type Error struct{ err string }
+
+func (e *Error) Error() string {
+	if e == nil {
+		return "dns: <nil>"
+	}
+	return "dns: " + e.err
+}
+
+// An RR represents a resource record.
+type RR interface {
+	// Header returns the header of a resource record. The header contains
+	// everything up to the rdata.
+	Header() *RR_Header
+	// String returns the text representation of the resource record.
+	String() string
+	// copy returns a copy of the RR.
+	copy() RR
+	// len returns the length (in octets) of the uncompressed RR in wire format.
+	len() int
+}
+
+// DNS resource records.
+// There are many types of RRs,
+// but they all share the same header.
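+// For example, the tests in this package build headers such as
+//
+//	RR_Header{Name: "miek.nl.", Rrtype: TypeA, Class: ClassINET, Ttl: 3600}
+//
+// and then fill in the type-specific rdata.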
+type RR_Header struct { + Name string `dns:"cdomain-name"` + Rrtype uint16 + Class uint16 + Ttl uint32 + Rdlength uint16 // length of data after header +} + +func (h *RR_Header) Header() *RR_Header { return h } + +// Just to imlement the RR interface +func (h *RR_Header) copy() RR { return nil } + +func (h *RR_Header) copyHeader() *RR_Header { + r := new(RR_Header) + r.Name = h.Name + r.Rrtype = h.Rrtype + r.Class = h.Class + r.Ttl = h.Ttl + r.Rdlength = h.Rdlength + return r +} + +func (h *RR_Header) String() string { + var s string + + if h.Rrtype == TypeOPT { + s = ";" + // and maybe other things + } + + s += sprintName(h.Name) + "\t" + s += strconv.FormatInt(int64(h.Ttl), 10) + "\t" + s += Class(h.Class).String() + "\t" + s += Type(h.Rrtype).String() + "\t" + return s +} + +func (h *RR_Header) len() int { + l := len(h.Name) + 1 + l += 10 // rrtype(2) + class(2) + ttl(4) + rdlength(2) + return l +} + +// ToRFC3597 converts a known RR to the unknown RR representation +// from RFC 3597. +func (rr *RFC3597) ToRFC3597(r RR) error { + buf := make([]byte, r.len()*2) + off, err := PackStruct(r, buf, 0) + if err != nil { + return err + } + buf = buf[:off] + rawSetRdlength(buf, 0, off) + _, err = UnpackStruct(rr, buf, 0) + if err != nil { + return err + } + return nil +} diff --git a/Godeps/_workspace/src/github.com/miekg/dns/dns_test.go b/Godeps/_workspace/src/github.com/miekg/dns/dns_test.go new file mode 100644 index 0000000000..b3063c90dd --- /dev/null +++ b/Godeps/_workspace/src/github.com/miekg/dns/dns_test.go @@ -0,0 +1,501 @@ +package dns + +import ( + "encoding/hex" + "net" + "testing" +) + +func TestPackUnpack(t *testing.T) { + out := new(Msg) + out.Answer = make([]RR, 1) + key := new(DNSKEY) + key = &DNSKEY{Flags: 257, Protocol: 3, Algorithm: RSASHA1} + key.Hdr = RR_Header{Name: "miek.nl.", Rrtype: TypeDNSKEY, Class: ClassINET, Ttl: 3600} + key.PublicKey = "AwEAAaHIwpx3w4VHKi6i1LHnTaWeHCL154Jug0Rtc9ji5qwPXpBo6A5sRv7cSsPQKPIwxLpyCrbJ4mr2L0EPOdvP6z6YfljK2ZmTbogU9aSU2fiq/4wjxbdkLyoDVgtO+JsxNN4bjr4WcWhsmk1Hg93FV9ZpkWb0Tbad8DFqNDzr//kZ" + + out.Answer[0] = key + msg, err := out.Pack() + if err != nil { + t.Log("failed to pack msg with DNSKEY") + t.Fail() + } + in := new(Msg) + if in.Unpack(msg) != nil { + t.Log("failed to unpack msg with DNSKEY") + t.Fail() + } + + sig := new(RRSIG) + sig = &RRSIG{TypeCovered: TypeDNSKEY, Algorithm: RSASHA1, Labels: 2, + OrigTtl: 3600, Expiration: 4000, Inception: 4000, KeyTag: 34641, SignerName: "miek.nl.", + Signature: "AwEAAaHIwpx3w4VHKi6i1LHnTaWeHCL154Jug0Rtc9ji5qwPXpBo6A5sRv7cSsPQKPIwxLpyCrbJ4mr2L0EPOdvP6z6YfljK2ZmTbogU9aSU2fiq/4wjxbdkLyoDVgtO+JsxNN4bjr4WcWhsmk1Hg93FV9ZpkWb0Tbad8DFqNDzr//kZ"} + sig.Hdr = RR_Header{Name: "miek.nl.", Rrtype: TypeRRSIG, Class: ClassINET, Ttl: 3600} + + out.Answer[0] = sig + msg, err = out.Pack() + if err != nil { + t.Log("failed to pack msg with RRSIG") + t.Fail() + } + + if in.Unpack(msg) != nil { + t.Log("failed to unpack msg with RRSIG") + t.Fail() + } +} + +func TestPackUnpack2(t *testing.T) { + m := new(Msg) + m.Extra = make([]RR, 1) + m.Answer = make([]RR, 1) + dom := "miek.nl." 
+ rr := new(A) + rr.Hdr = RR_Header{Name: dom, Rrtype: TypeA, Class: ClassINET, Ttl: 0} + rr.A = net.IPv4(127, 0, 0, 1) + + x := new(TXT) + x.Hdr = RR_Header{Name: dom, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0} + x.Txt = []string{"heelalaollo"} + + m.Extra[0] = x + m.Answer[0] = rr + _, err := m.Pack() + if err != nil { + t.Log("Packing failed: " + err.Error()) + t.Fail() + return + } +} + +func TestPackUnpack3(t *testing.T) { + m := new(Msg) + m.Extra = make([]RR, 2) + m.Answer = make([]RR, 1) + dom := "miek.nl." + rr := new(A) + rr.Hdr = RR_Header{Name: dom, Rrtype: TypeA, Class: ClassINET, Ttl: 0} + rr.A = net.IPv4(127, 0, 0, 1) + + x1 := new(TXT) + x1.Hdr = RR_Header{Name: dom, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0} + x1.Txt = []string{} + + x2 := new(TXT) + x2.Hdr = RR_Header{Name: dom, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0} + x2.Txt = []string{"heelalaollo"} + + m.Extra[0] = x1 + m.Extra[1] = x2 + m.Answer[0] = rr + b, err := m.Pack() + if err != nil { + t.Log("packing failed: " + err.Error()) + t.Fail() + return + } + + var unpackMsg Msg + err = unpackMsg.Unpack(b) + if err != nil { + t.Log("unpacking failed") + t.Fail() + return + } +} + +func TestBailiwick(t *testing.T) { + yes := map[string]string{ + "miek.nl": "ns.miek.nl", + ".": "miek.nl", + } + for parent, child := range yes { + if !IsSubDomain(parent, child) { + t.Logf("%s should be child of %s\n", child, parent) + t.Logf("comparelabels %d", CompareDomainName(parent, child)) + t.Logf("lenlabels %d %d", CountLabel(parent), CountLabel(child)) + t.Fail() + } + } + no := map[string]string{ + "www.miek.nl": "ns.miek.nl", + "m\\.iek.nl": "ns.miek.nl", + "w\\.iek.nl": "w.iek.nl", + "p\\\\.iek.nl": "ns.p.iek.nl", // p\\.iek.nl , literal \ in domain name + "miek.nl": ".", + } + for parent, child := range no { + if IsSubDomain(parent, child) { + t.Logf("%s should not be child of %s\n", child, parent) + t.Logf("comparelabels %d", CompareDomainName(parent, child)) + t.Logf("lenlabels %d %d", CountLabel(parent), CountLabel(child)) + t.Fail() + } + } +} + +func TestPack(t *testing.T) { + rr := []string{"US. 86400 IN NSEC 0-.us. NS SOA RRSIG NSEC DNSKEY TYPE65534"} + m := new(Msg) + var err error + m.Answer = make([]RR, 1) + for _, r := range rr { + m.Answer[0], err = NewRR(r) + if err != nil { + t.Logf("failed to create RR: %s\n", err.Error()) + t.Fail() + continue + } + if _, err := m.Pack(); err != nil { + t.Logf("packing failed: %s\n", err.Error()) + t.Fail() + } + } + x := new(Msg) + ns, _ := NewRR("pool.ntp.org. 390 IN NS a.ntpns.org") + ns.(*NS).Ns = "a.ntpns.org" + x.Ns = append(m.Ns, ns) + x.Ns = append(m.Ns, ns) + x.Ns = append(m.Ns, ns) + // This crashes due to the fact the a.ntpns.org isn't a FQDN + // How to recover() from a remove panic()? + if _, err := x.Pack(); err == nil { + t.Log("packing should fail") + t.Fail() + } + x.Answer = make([]RR, 1) + x.Answer[0], err = NewRR(rr[0]) + if _, err := x.Pack(); err == nil { + t.Log("packing should fail") + t.Fail() + } + x.Question = make([]Question, 1) + x.Question[0] = Question{";sd#edddds鍛↙赏‘℅∥↙xzztsestxssweewwsssstx@s@Z嵌e@cn.pool.ntp.org.", TypeA, ClassINET} + if _, err := x.Pack(); err == nil { + t.Log("packing should fail") + t.Fail() + } +} + +func TestPackNAPTR(t *testing.T) { + for _, n := range []string{ + `apple.com. IN NAPTR 100 50 "se" "SIP+D2U" "" _sip._udp.apple.com.`, + `apple.com. IN NAPTR 90 50 "se" "SIP+D2T" "" _sip._tcp.apple.com.`, + `apple.com. 
IN NAPTR 50 50 "se" "SIPS+D2T" "" _sips._tcp.apple.com.`, + } { + rr, _ := NewRR(n) + msg := make([]byte, rr.len()) + if off, err := PackRR(rr, msg, 0, nil, false); err != nil { + t.Logf("packing failed: %s", err.Error()) + t.Logf("length %d, need more than %d\n", rr.len(), off) + t.Fail() + } else { + t.Logf("buf size needed: %d\n", off) + } + } +} + +func TestCompressLength(t *testing.T) { + m := new(Msg) + m.SetQuestion("miek.nl", TypeMX) + ul := m.Len() + m.Compress = true + if ul != m.Len() { + t.Fatalf("should be equal") + } +} + +// Does the predicted length match final packed length? +func TestMsgCompressLength(t *testing.T) { + makeMsg := func(question string, ans, ns, e []RR) *Msg { + msg := new(Msg) + msg.SetQuestion(Fqdn(question), TypeANY) + msg.Answer = append(msg.Answer, ans...) + msg.Ns = append(msg.Ns, ns...) + msg.Extra = append(msg.Extra, e...) + msg.Compress = true + return msg + } + + name1 := "12345678901234567890123456789012345.12345678.123." + rrA, _ := NewRR(name1 + " 3600 IN A 192.0.2.1") + rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1) + tests := []*Msg{ + makeMsg(name1, []RR{rrA}, nil, nil), + makeMsg(name1, []RR{rrMx, rrMx}, nil, nil)} + + for _, msg := range tests { + predicted := msg.Len() + buf, err := msg.Pack() + if err != nil { + t.Error(err) + t.Fail() + } + if predicted < len(buf) { + t.Errorf("predicted compressed length is wrong: predicted %s (len=%d) %d, actual %d\n", + msg.Question[0].Name, len(msg.Answer), predicted, len(buf)) + t.Fail() + } + } +} + +func TestMsgLength(t *testing.T) { + makeMsg := func(question string, ans, ns, e []RR) *Msg { + msg := new(Msg) + msg.SetQuestion(Fqdn(question), TypeANY) + msg.Answer = append(msg.Answer, ans...) + msg.Ns = append(msg.Ns, ns...) + msg.Extra = append(msg.Extra, e...) + return msg + } + + name1 := "12345678901234567890123456789012345.12345678.123." + rrA, _ := NewRR(name1 + " 3600 IN A 192.0.2.1") + rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1) + tests := []*Msg{ + makeMsg(name1, []RR{rrA}, nil, nil), + makeMsg(name1, []RR{rrMx, rrMx}, nil, nil)} + + for _, msg := range tests { + predicted := msg.Len() + buf, err := msg.Pack() + if err != nil { + t.Error(err) + t.Fail() + } + if predicted < len(buf) { + t.Errorf("predicted length is wrong: predicted %s (len=%d), actual %d\n", + msg.Question[0].Name, predicted, len(buf)) + t.Fail() + } + } +} + +func TestMsgLength2(t *testing.T) { + // Serialized replies + var testMessages = []string{ + // google.com. IN A? + "064e81800001000b0004000506676f6f676c6503636f6d0000010001c00c00010001000000050004adc22986c00c00010001000000050004adc22987c00c00010001000000050004adc22988c00c00010001000000050004adc22989c00c00010001000000050004adc2298ec00c00010001000000050004adc22980c00c00010001000000050004adc22981c00c00010001000000050004adc22982c00c00010001000000050004adc22983c00c00010001000000050004adc22984c00c00010001000000050004adc22985c00c00020001000000050006036e7331c00cc00c00020001000000050006036e7332c00cc00c00020001000000050006036e7333c00cc00c00020001000000050006036e7334c00cc0d800010001000000050004d8ef200ac0ea00010001000000050004d8ef220ac0fc00010001000000050004d8ef240ac10e00010001000000050004d8ef260a0000290500000000050000", + // amazon.com. IN A? 
(reply has no EDNS0 record) + // TODO(miek): this one is off-by-one, need to find out why + //"6de1818000010004000a000806616d617a6f6e03636f6d0000010001c00c000100010000000500044815c2d4c00c000100010000000500044815d7e8c00c00010001000000050004b02062a6c00c00010001000000050004cdfbf236c00c000200010000000500140570646e733408756c747261646e73036f726700c00c000200010000000500150570646e733508756c747261646e7304696e666f00c00c000200010000000500160570646e733608756c747261646e7302636f02756b00c00c00020001000000050014036e7331037033310664796e656374036e657400c00c00020001000000050006036e7332c0cfc00c00020001000000050006036e7333c0cfc00c00020001000000050006036e7334c0cfc00c000200010000000500110570646e733108756c747261646e73c0dac00c000200010000000500080570646e7332c127c00c000200010000000500080570646e7333c06ec0cb00010001000000050004d04e461fc0eb00010001000000050004cc0dfa1fc0fd00010001000000050004d04e471fc10f00010001000000050004cc0dfb1fc12100010001000000050004cc4a6c01c121001c000100000005001020010502f3ff00000000000000000001c13e00010001000000050004cc4a6d01c13e001c0001000000050010261000a1101400000000000000000001", + // yahoo.com. IN A? + "fc2d81800001000300070008057961686f6f03636f6d0000010001c00c00010001000000050004628afd6dc00c00010001000000050004628bb718c00c00010001000000050004cebe242dc00c00020001000000050006036e7336c00cc00c00020001000000050006036e7338c00cc00c00020001000000050006036e7331c00cc00c00020001000000050006036e7332c00cc00c00020001000000050006036e7333c00cc00c00020001000000050006036e7334c00cc00c00020001000000050006036e7335c00cc07b0001000100000005000444b48310c08d00010001000000050004448eff10c09f00010001000000050004cb54dd35c0b100010001000000050004628a0b9dc0c30001000100000005000477a0f77cc05700010001000000050004ca2bdfaac06900010001000000050004caa568160000290500000000050000", + // microsoft.com. IN A? + "f4368180000100020005000b096d6963726f736f667403636f6d0000010001c00c0001000100000005000440040b25c00c0001000100000005000441373ac9c00c0002000100000005000e036e7331046d736674036e657400c00c00020001000000050006036e7332c04fc00c00020001000000050006036e7333c04fc00c00020001000000050006036e7334c04fc00c00020001000000050006036e7335c04fc04b000100010000000500044137253ec04b001c00010000000500102a010111200500000000000000010001c0650001000100000005000440043badc065001c00010000000500102a010111200600060000000000010001c07700010001000000050004d5c7b435c077001c00010000000500102a010111202000000000000000010001c08900010001000000050004cf2e4bfec089001c00010000000500102404f800200300000000000000010001c09b000100010000000500044137e28cc09b001c00010000000500102a010111200f000100000000000100010000290500000000050000", + // google.com. IN MX? + "724b8180000100050004000b06676f6f676c6503636f6d00000f0001c00c000f000100000005000c000a056173706d78016cc00cc00c000f0001000000050009001404616c7431c02ac00c000f0001000000050009001e04616c7432c02ac00c000f0001000000050009002804616c7433c02ac00c000f0001000000050009003204616c7434c02ac00c00020001000000050006036e7332c00cc00c00020001000000050006036e7333c00cc00c00020001000000050006036e7334c00cc00c00020001000000050006036e7331c00cc02a00010001000000050004adc2421bc02a001c00010000000500102a00145040080c01000000000000001bc04200010001000000050004adc2461bc05700010001000000050004adc2451bc06c000100010000000500044a7d8f1bc081000100010000000500044a7d191bc0ca00010001000000050004d8ef200ac09400010001000000050004d8ef220ac0a600010001000000050004d8ef240ac0b800010001000000050004d8ef260a0000290500000000050000", + // reddit.com. IN A? 
+ "12b98180000100080000000c0672656464697403636f6d0000020001c00c0002000100000005000f046175733204616b616d036e657400c00c000200010000000500070475736534c02dc00c000200010000000500070475737733c02dc00c000200010000000500070475737735c02dc00c00020001000000050008056173696131c02dc00c00020001000000050008056173696139c02dc00c00020001000000050008056e73312d31c02dc00c0002000100000005000a076e73312d313935c02dc02800010001000000050004c30a242ec04300010001000000050004451f1d39c05600010001000000050004451f3bc7c0690001000100000005000460073240c07c000100010000000500046007fb81c090000100010000000500047c283484c090001c00010000000500102a0226f0006700000000000000000064c0a400010001000000050004c16c5b01c0a4001c000100000005001026001401000200000000000000000001c0b800010001000000050004c16c5bc3c0b8001c0001000000050010260014010002000000000000000000c30000290500000000050000", + } + + for i, hexData := range testMessages { + // we won't fail the decoding of the hex + input, _ := hex.DecodeString(hexData) + m := new(Msg) + m.Unpack(input) + //println(m.String()) + m.Compress = true + lenComp := m.Len() + b, _ := m.Pack() + pacComp := len(b) + m.Compress = false + lenUnComp := m.Len() + b, _ = m.Pack() + pacUnComp := len(b) + if pacComp+1 != lenComp { + t.Errorf("msg.Len(compressed)=%d actual=%d for test %d", lenComp, pacComp, i) + } + if pacUnComp+1 != lenUnComp { + t.Errorf("msg.Len(uncompressed)=%d actual=%d for test %d", lenUnComp, pacUnComp, i) + } + } +} + +func TestMsgLengthCompressionMalformed(t *testing.T) { + // SOA with empty hostmaster, which is illegal + soa := &SOA{Hdr: RR_Header{Name: ".", Rrtype: TypeSOA, Class: ClassINET, Ttl: 12345}, + Ns: ".", + Mbox: "", + Serial: 0, + Refresh: 28800, + Retry: 7200, + Expire: 604800, + Minttl: 60} + m := new(Msg) + m.Compress = true + m.Ns = []RR{soa} + m.Len() // Should not crash. +} + +func BenchmarkMsgLength(b *testing.B) { + b.StopTimer() + makeMsg := func(question string, ans, ns, e []RR) *Msg { + msg := new(Msg) + msg.SetQuestion(Fqdn(question), TypeANY) + msg.Answer = append(msg.Answer, ans...) + msg.Ns = append(msg.Ns, ns...) + msg.Extra = append(msg.Extra, e...) + msg.Compress = true + return msg + } + name1 := "12345678901234567890123456789012345.12345678.123." + rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1) + msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil) + b.StartTimer() + for i := 0; i < b.N; i++ { + msg.Len() + } +} + +func BenchmarkMsgLengthPack(b *testing.B) { + makeMsg := func(question string, ans, ns, e []RR) *Msg { + msg := new(Msg) + msg.SetQuestion(Fqdn(question), TypeANY) + msg.Answer = append(msg.Answer, ans...) + msg.Ns = append(msg.Ns, ns...) + msg.Extra = append(msg.Extra, e...) + msg.Compress = true + return msg + } + name1 := "12345678901234567890123456789012345.12345678.123." + rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1) + msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = msg.Pack() + } +} + +func BenchmarkMsgPackBuffer(b *testing.B) { + makeMsg := func(question string, ans, ns, e []RR) *Msg { + msg := new(Msg) + msg.SetQuestion(Fqdn(question), TypeANY) + msg.Answer = append(msg.Answer, ans...) + msg.Ns = append(msg.Ns, ns...) + msg.Extra = append(msg.Extra, e...) + msg.Compress = true + return msg + } + name1 := "12345678901234567890123456789012345.12345678.123." 
+	rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1)
+	msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil)
+	buf := make([]byte, 512)
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_, _ = msg.PackBuffer(buf)
+	}
+}
+
+func BenchmarkMsgUnpack(b *testing.B) {
+	makeMsg := func(question string, ans, ns, e []RR) *Msg {
+		msg := new(Msg)
+		msg.SetQuestion(Fqdn(question), TypeANY)
+		msg.Answer = append(msg.Answer, ans...)
+		msg.Ns = append(msg.Ns, ns...)
+		msg.Extra = append(msg.Extra, e...)
+		msg.Compress = true
+		return msg
+	}
+	name1 := "12345678901234567890123456789012345.12345678.123."
+	rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1)
+	msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil)
+	msgBuf, _ := msg.Pack()
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_ = msg.Unpack(msgBuf)
+	}
+}
+
+func BenchmarkPackDomainName(b *testing.B) {
+	name1 := "12345678901234567890123456789012345.12345678.123."
+	buf := make([]byte, len(name1)+1)
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_, _ = PackDomainName(name1, buf, 0, nil, false)
+	}
+}
+
+func BenchmarkUnpackDomainName(b *testing.B) {
+	name1 := "12345678901234567890123456789012345.12345678.123."
+	buf := make([]byte, len(name1)+1)
+	_, _ = PackDomainName(name1, buf, 0, nil, false)
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_, _, _ = UnpackDomainName(buf, 0)
+	}
+}
+
+func TestToRFC3597(t *testing.T) {
+	a, _ := NewRR("miek.nl. IN A 10.0.1.1")
+	x := new(RFC3597)
+	x.ToRFC3597(a)
+	if x.String() != `miek.nl.	3600	IN	A	\# 4 0a000101` {
+		t.Fail()
+	}
+}
+
+func TestNoRdataPack(t *testing.T) {
+	data := make([]byte, 1024)
+	for typ, fn := range typeToRR {
+		if typ == TypeCAA {
+			continue // TODO(miek): known omission
+		}
+		r := fn()
+		*r.Header() = RR_Header{Name: "miek.nl.", Rrtype: typ, Class: ClassINET, Ttl: 3600}
+		_, e := PackRR(r, data, 0, nil, false)
+		if e != nil {
+			t.Logf("failed to pack RR with zero rdata: %s: %s\n", TypeToString[typ], e.Error())
+			t.Fail()
+		}
+	}
+}
+
+// TODO(miek): fix the "dns: buffer size too small" errors this throws.
+func TestNoRdataUnpack(t *testing.T) {
+	data := make([]byte, 1024)
+	for typ, fn := range typeToRR {
+		if typ == TypeSOA || typ == TypeTSIG || typ == TypeWKS {
+			// SOA and TSIG will not be seen (like this) in dynamic updates?
+			// WKS is a bug, but... a deprecated record.
+			continue
+		}
+		r := fn()
+		*r.Header() = RR_Header{Name: "miek.nl.", Rrtype: typ, Class: ClassINET, Ttl: 3600}
+		off, e := PackRR(r, data, 0, nil, false)
+		if e != nil {
+			// Should always work; TestNoRdataPack should have caught any failure.
+			continue
+		}
+		rr, _, e := UnpackRR(data[:off], 0)
+		if e != nil {
+			t.Logf("failed to unpack RR with zero rdata: %s: %s\n", TypeToString[typ], e.Error())
+			t.Fail()
+		}
+		t.Logf("%s\n", rr)
+	}
+}
+
+func TestRdataOverflow(t *testing.T) {
+	rr := new(RFC3597)
+	rr.Hdr.Name = "."
+	rr.Hdr.Class = ClassINET
+	rr.Hdr.Rrtype = 65280
+	rr.Rdata = hex.EncodeToString(make([]byte, 0xFFFF))
+	buf := make([]byte, 0xFFFF*2)
+	if _, err := PackRR(rr, buf, 0, nil, false); err != nil {
+		t.Fatalf("maximum size rrdata pack failed: %v", err)
+	}
+	rr.Rdata += "00"
+	if _, err := PackRR(rr, buf, 0, nil, false); err != ErrRdata {
+		t.Fatalf("oversize rrdata pack didn't return ErrRdata - instead: %v", err)
+	}
+}
+
+func TestCopy(t *testing.T) {
+	rr, _ := NewRR("miek.nl. 2311 IN A 127.0.0.1") // Unusual TTL so a default TTL cannot mask a copy failure
+	rr1 := Copy(rr)
+	if rr.String() != rr1.String() {
+		t.Fatalf("Copy() failed %s != %s", rr.String(), rr1.String())
+	}
+}
diff --git a/Godeps/_workspace/src/github.com/miekg/dns/dnssec.go b/Godeps/_workspace/src/github.com/miekg/dns/dnssec.go
new file mode 100644
index 0000000000..db32fe5544
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/miekg/dns/dnssec.go
@@ -0,0 +1,739 @@
+// DNSSEC
+//
+// DNSSEC (DNS Security Extensions) adds a layer of security to the DNS. It
+// uses public key cryptography to sign resource records. The
+// public keys are stored in DNSKEY records and the signatures in RRSIG records.
+//
+// Requesting DNSSEC information for a zone is done by adding the DO (DNSSEC OK)
+// bit to a request.
+//
+// m := new(dns.Msg)
+// m.SetEdns0(4096, true)
+//
+// Signature generation, signature verification and key generation are all supported.
+package dns
+
+import (
+	"bytes"
+	"crypto"
+	"crypto/dsa"
+	"crypto/ecdsa"
+	"crypto/elliptic"
+	"crypto/md5"
+	"crypto/rand"
+	"crypto/rsa"
+	"crypto/sha1"
+	"crypto/sha256"
+	"crypto/sha512"
+	"encoding/hex"
+	"hash"
+	"io"
+	"math/big"
+	"sort"
+	"strings"
+	"time"
+)
+
+// DNSSEC signing algorithm codes.
+const (
+	RSAMD5           = 1
+	DH               = 2
+	DSA              = 3
+	ECC              = 4
+	RSASHA1          = 5
+	DSANSEC3SHA1     = 6
+	RSASHA1NSEC3SHA1 = 7
+	RSASHA256        = 8
+	RSASHA512        = 10
+	ECCGOST          = 12
+	ECDSAP256SHA256  = 13
+	ECDSAP384SHA384  = 14
+	INDIRECT         = 252
+	PRIVATEDNS       = 253 // Private (experimental keys)
+	PRIVATEOID       = 254
+)
+
+// DNSSEC hashing algorithm codes.
+const (
+	_      = iota
+	SHA1   // RFC 4034
+	SHA256 // RFC 4509
+	GOST94 // RFC 5933
+	SHA384 // Experimental
+	SHA512 // Experimental
+)
+
+// DNSKEY flag values.
+const (
+	SEP    = 1
+	REVOKE = 1 << 7
+	ZONE   = 1 << 8
+)
+
+// The RRSIG needs to be converted to wire format with some of
+// the rdata (the signature) missing. Use this struct to ease
+// the conversion (and re-use the pack/unpack functions).
+type rrsigWireFmt struct {
+	TypeCovered uint16
+	Algorithm   uint8
+	Labels      uint8
+	OrigTtl     uint32
+	Expiration  uint32
+	Inception   uint32
+	KeyTag      uint16
+	SignerName  string `dns:"domain-name"`
+	/* No Signature */
+}
+
+// Used for converting a DNSKEY's rdata to wire format.
+type dnskeyWireFmt struct {
+	Flags     uint16
+	Protocol  uint8
+	Algorithm uint8
+	PublicKey string `dns:"base64"`
+	/* Nothing is left out */
+}
+
+// KeyTag calculates the keytag (or key-id) of the DNSKEY.
+func (k *DNSKEY) KeyTag() uint16 {
+	if k == nil {
+		return 0
+	}
+	var keytag int
+	switch k.Algorithm {
+	case RSAMD5:
+		// Look at the bottom two bytes of the modulus, which is the last
+		// item in the pubkey. We could do this faster by looking directly
+		// at the base64 values. But I'm lazy.
+		modulus, _ := packBase64([]byte(k.PublicKey))
+		if len(modulus) > 1 {
+			x, _ := unpackUint16(modulus, len(modulus)-2)
+			keytag = int(x)
+		}
+	default:
+		keywire := new(dnskeyWireFmt)
+		keywire.Flags = k.Flags
+		keywire.Protocol = k.Protocol
+		keywire.Algorithm = k.Algorithm
+		keywire.PublicKey = k.PublicKey
+		wire := make([]byte, DefaultMsgSize)
+		n, err := PackStruct(keywire, wire, 0)
+		if err != nil {
+			return 0
+		}
+		wire = wire[:n]
+		for i, v := range wire {
+			if i&1 != 0 {
+				keytag += int(v) // must be larger than uint32
+			} else {
+				keytag += int(v) << 8
+			}
+		}
+		keytag += (keytag >> 16) & 0xFFFF
+		keytag &= 0xFFFF
+	}
+	return uint16(keytag)
+}
+
+// ToDS converts a DNSKEY record to a DS record.
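+// The digest is computed over the canonical (lowercased) owner name
+// concatenated with the DNSKEY RDATA, per RFC 4034, section 5.1.4.
+// A typical call, with k a *DNSKEY (illustrative):
+//
+//	ds := k.ToDS(SHA256) // returns nil for unsupported digest types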
+func (k *DNSKEY) ToDS(h int) *DS {
+	if k == nil {
+		return nil
+	}
+	ds := new(DS)
+	ds.Hdr.Name = k.Hdr.Name
+	ds.Hdr.Class = k.Hdr.Class
+	ds.Hdr.Rrtype = TypeDS
+	ds.Hdr.Ttl = k.Hdr.Ttl
+	ds.Algorithm = k.Algorithm
+	ds.DigestType = uint8(h)
+	ds.KeyTag = k.KeyTag()
+
+	keywire := new(dnskeyWireFmt)
+	keywire.Flags = k.Flags
+	keywire.Protocol = k.Protocol
+	keywire.Algorithm = k.Algorithm
+	keywire.PublicKey = k.PublicKey
+	wire := make([]byte, DefaultMsgSize)
+	n, err := PackStruct(keywire, wire, 0)
+	if err != nil {
+		return nil
+	}
+	wire = wire[:n]
+
+	owner := make([]byte, 255)
+	off, err1 := PackDomainName(strings.ToLower(k.Hdr.Name), owner, 0, nil, false)
+	if err1 != nil {
+		return nil
+	}
+	owner = owner[:off]
+	// RFC4034:
+	// digest = digest_algorithm( DNSKEY owner name | DNSKEY RDATA);
+	// "|" denotes concatenation
+	// DNSKEY RDATA = Flags | Protocol | Algorithm | Public Key.
+
+	// digest buffer
+	digest := append(owner, wire...) // another copy
+
+	switch h {
+	case SHA1:
+		s := sha1.New()
+		io.WriteString(s, string(digest))
+		ds.Digest = hex.EncodeToString(s.Sum(nil))
+	case SHA256:
+		s := sha256.New()
+		io.WriteString(s, string(digest))
+		ds.Digest = hex.EncodeToString(s.Sum(nil))
+	case SHA384:
+		s := sha512.New384()
+		io.WriteString(s, string(digest))
+		ds.Digest = hex.EncodeToString(s.Sum(nil))
+	case GOST94:
+		// GOST R 34.11-94 is not implemented; the digest is left empty.
+	default:
+		return nil
+	}
+	return ds
+}
+
+// Sign signs an RRset. The signature needs to be filled in with the values:
+// Inception, Expiration, KeyTag, SignerName and Algorithm. The rest is
+// copied from the RRset. Sign returns a non-nil error when the signing failed.
+// There is no check that the RRset is a proper (RFC 2181) RRset.
+// If OrigTtl is non-zero, it is used as-is, otherwise the TTL of the RRset
+// is used as the OrigTtl.
+func (rr *RRSIG) Sign(k PrivateKey, rrset []RR) error {
+	if k == nil {
+		return ErrPrivKey
+	}
+	// rr.Inception and rr.Expiration may be 0 (rollover etc.), the rest must be set
+	if rr.KeyTag == 0 || len(rr.SignerName) == 0 || rr.Algorithm == 0 {
+		return ErrKey
+	}
+
+	rr.Hdr.Rrtype = TypeRRSIG
+	rr.Hdr.Name = rrset[0].Header().Name
+	rr.Hdr.Class = rrset[0].Header().Class
+	if rr.OrigTtl == 0 { // if OrigTtl is set, don't override it
+		rr.OrigTtl = rrset[0].Header().Ttl
+	}
+	rr.TypeCovered = rrset[0].Header().Rrtype
+	rr.Labels = uint8(CountLabel(rrset[0].Header().Name))
+
+	if strings.HasPrefix(rrset[0].Header().Name, "*") {
+		rr.Labels-- // wildcard, remove from label count
+	}
+
+	sigwire := new(rrsigWireFmt)
+	sigwire.TypeCovered = rr.TypeCovered
+	sigwire.Algorithm = rr.Algorithm
+	sigwire.Labels = rr.Labels
+	sigwire.OrigTtl = rr.OrigTtl
+	sigwire.Expiration = rr.Expiration
+	sigwire.Inception = rr.Inception
+	sigwire.KeyTag = rr.KeyTag
+	// For signing, lowercase this name
+	sigwire.SignerName = strings.ToLower(rr.SignerName)
+
+	// Create the desired binary blob
+	signdata := make([]byte, DefaultMsgSize)
+	n, err := PackStruct(sigwire, signdata, 0)
+	if err != nil {
+		return err
+	}
+	signdata = signdata[:n]
+	wire, err := rawSignatureData(rrset, rr)
+	if err != nil {
+		return err
+	}
+	signdata = append(signdata, wire...)
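+	// signdata now holds the RRSIG RDATA with the signature field omitted,
+	// followed by the RRset in canonical wire form; this is exactly the
+	// blob that is hashed and signed (RFC 4034, section 3.1.8.1).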
+ + var sighash []byte + var h hash.Hash + var ch crypto.Hash // Only need for RSA + switch rr.Algorithm { + case DSA, DSANSEC3SHA1: + // Implicit in the ParameterSizes + case RSASHA1, RSASHA1NSEC3SHA1: + h = sha1.New() + ch = crypto.SHA1 + case RSASHA256, ECDSAP256SHA256: + h = sha256.New() + ch = crypto.SHA256 + case ECDSAP384SHA384: + h = sha512.New384() + case RSASHA512: + h = sha512.New() + ch = crypto.SHA512 + case RSAMD5: + fallthrough // Deprecated in RFC 6725 + default: + return ErrAlg + } + io.WriteString(h, string(signdata)) + sighash = h.Sum(nil) + + switch p := k.(type) { + case *dsa.PrivateKey: + r1, s1, err := dsa.Sign(rand.Reader, p, sighash) + if err != nil { + return err + } + signature := []byte{0x4D} // T value, here the ASCII M for Miek (not used in DNSSEC) + signature = append(signature, r1.Bytes()...) + signature = append(signature, s1.Bytes()...) + rr.Signature = unpackBase64(signature) + case *rsa.PrivateKey: + // We can use nil as rand.Reader here (says AGL) + signature, err := rsa.SignPKCS1v15(nil, p, ch, sighash) + if err != nil { + return err + } + rr.Signature = unpackBase64(signature) + case *ecdsa.PrivateKey: + r1, s1, err := ecdsa.Sign(rand.Reader, p, sighash) + if err != nil { + return err + } + signature := r1.Bytes() + signature = append(signature, s1.Bytes()...) + rr.Signature = unpackBase64(signature) + default: + // Not given the correct key + return ErrKeyAlg + } + return nil +} + +// Verify validates an RRSet with the signature and key. This is only the +// cryptographic test, the signature validity period must be checked separately. +// This function copies the rdata of some RRs (to lowercase domain names) for the validation to work. +func (rr *RRSIG) Verify(k *DNSKEY, rrset []RR) error { + // First the easy checks + if len(rrset) == 0 { + return ErrRRset + } + if rr.KeyTag != k.KeyTag() { + return ErrKey + } + if rr.Hdr.Class != k.Hdr.Class { + return ErrKey + } + if rr.Algorithm != k.Algorithm { + return ErrKey + } + if strings.ToLower(rr.SignerName) != strings.ToLower(k.Hdr.Name) { + return ErrKey + } + if k.Protocol != 3 { + return ErrKey + } + for _, r := range rrset { + if r.Header().Class != rr.Hdr.Class { + return ErrRRset + } + if r.Header().Rrtype != rr.TypeCovered { + return ErrRRset + } + } + // RFC 4035 5.3.2. Reconstructing the Signed Data + // Copy the sig, except the rrsig data + sigwire := new(rrsigWireFmt) + sigwire.TypeCovered = rr.TypeCovered + sigwire.Algorithm = rr.Algorithm + sigwire.Labels = rr.Labels + sigwire.OrigTtl = rr.OrigTtl + sigwire.Expiration = rr.Expiration + sigwire.Inception = rr.Inception + sigwire.KeyTag = rr.KeyTag + sigwire.SignerName = strings.ToLower(rr.SignerName) + // Create the desired binary blob + signeddata := make([]byte, DefaultMsgSize) + n, err := PackStruct(sigwire, signeddata, 0) + if err != nil { + return err + } + signeddata = signeddata[:n] + wire, err := rawSignatureData(rrset, rr) + if err != nil { + return err + } + signeddata = append(signeddata, wire...) + + sigbuf := rr.sigBuf() // Get the binary signature data + if rr.Algorithm == PRIVATEDNS { // PRIVATEOID + // TODO(mg) + // remove the domain name and assume its our + } + + switch rr.Algorithm { + case RSASHA1, RSASHA1NSEC3SHA1, RSASHA256, RSASHA512, RSAMD5: + // TODO(mg): this can be done quicker, ie. cache the pubkey data somewhere?? + pubkey := k.publicKeyRSA() // Get the key + if pubkey == nil { + return ErrKey + } + // Setup the hash as defined for this alg. 
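+		// This mirrors the switch in Sign, except that RSAMD5 is still
+		// accepted here so existing signatures can be verified, even
+		// though Sign refuses to create new ones (RFC 6725).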
+		var h hash.Hash
+		var ch crypto.Hash
+		switch rr.Algorithm {
+		case RSAMD5:
+			h = md5.New()
+			ch = crypto.MD5
+		case RSASHA1, RSASHA1NSEC3SHA1:
+			h = sha1.New()
+			ch = crypto.SHA1
+		case RSASHA256:
+			h = sha256.New()
+			ch = crypto.SHA256
+		case RSASHA512:
+			h = sha512.New()
+			ch = crypto.SHA512
+		}
+		io.WriteString(h, string(signeddata))
+		sighash := h.Sum(nil)
+		return rsa.VerifyPKCS1v15(pubkey, ch, sighash, sigbuf)
+	case ECDSAP256SHA256, ECDSAP384SHA384:
+		pubkey := k.publicKeyCurve()
+		if pubkey == nil {
+			return ErrKey
+		}
+		var h hash.Hash
+		switch rr.Algorithm {
+		case ECDSAP256SHA256:
+			h = sha256.New()
+		case ECDSAP384SHA384:
+			h = sha512.New384()
+		}
+		io.WriteString(h, string(signeddata))
+		sighash := h.Sum(nil)
+		// Split sigbuf into the r and s coordinates
+		r := big.NewInt(0)
+		r.SetBytes(sigbuf[:len(sigbuf)/2])
+		s := big.NewInt(0)
+		s.SetBytes(sigbuf[len(sigbuf)/2:])
+		if ecdsa.Verify(pubkey, sighash, r, s) {
+			return nil
+		}
+		return ErrSig
+	}
+	// Unknown alg
+	return ErrAlg
+}
+
+// ValidityPeriod uses RFC 1982 serial arithmetic to calculate
+// whether a signature period is valid. If t is the zero time, the
+// current time is used; otherwise t is.
+func (rr *RRSIG) ValidityPeriod(t time.Time) bool {
+	var utc int64
+	if t.IsZero() {
+		utc = time.Now().UTC().Unix()
+	} else {
+		utc = t.UTC().Unix()
+	}
+	modi := (int64(rr.Inception) - utc) / year68
+	mode := (int64(rr.Expiration) - utc) / year68
+	ti := int64(rr.Inception) + (modi * year68)
+	te := int64(rr.Expiration) + (mode * year68)
+	return ti <= utc && utc <= te
+}
+
+// Return the signature's base64-encoded sigdata as a byte slice.
+func (s *RRSIG) sigBuf() []byte {
+	sigbuf, err := packBase64([]byte(s.Signature))
+	if err != nil {
+		return nil
+	}
+	return sigbuf
+}
+
+// setPublicKeyInPrivate sets the public key in the private key.
+func (k *DNSKEY) setPublicKeyInPrivate(p PrivateKey) bool {
+	switch t := p.(type) {
+	case *dsa.PrivateKey:
+		x := k.publicKeyDSA()
+		if x == nil {
+			return false
+		}
+		t.PublicKey = *x
+	case *rsa.PrivateKey:
+		x := k.publicKeyRSA()
+		if x == nil {
+			return false
+		}
+		t.PublicKey = *x
+	case *ecdsa.PrivateKey:
+		x := k.publicKeyCurve()
+		if x == nil {
+			return false
+		}
+		t.PublicKey = *x
+	}
+	return true
+}
+
+// publicKeyRSA returns the RSA public key from a DNSKEY record.
+func (k *DNSKEY) publicKeyRSA() *rsa.PublicKey {
+	keybuf, err := packBase64([]byte(k.PublicKey))
+	if err != nil {
+		return nil
+	}
+
+	// RFC 2537/3110, section 2. RSA Public KEY Resource Records
+	// The exponent length is in the 0th byte, unless it is zero,
+	// in which case it is a 16-bit number in bytes 1 and 2.
+	explen := uint16(keybuf[0])
+	keyoff := 1
+	if explen == 0 {
+		explen = uint16(keybuf[1])<<8 | uint16(keybuf[2])
+		keyoff = 3
+	}
+	pubkey := new(rsa.PublicKey)
+
+	pubkey.N = big.NewInt(0)
+	shift := uint64((explen - 1) * 8)
+	expo := uint64(0)
+	for i := int(explen - 1); i > 0; i-- {
+		expo += uint64(keybuf[keyoff+i]) << shift
+		shift -= 8
+	}
+	// Remainder
+	expo += uint64(keybuf[keyoff])
+	if expo > 2<<31 {
+		// Larger exponent than supported.
+		// println("dns: F5 primes (or larger) are not supported")
+		return nil
+	}
+	pubkey.E = int(expo)
+
+	pubkey.N.SetBytes(keybuf[keyoff+int(explen):])
+	return pubkey
+}
+
+// publicKeyCurve returns the Curve public key from the DNSKEY record.
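+// The key material is the uncompressed point: the X and Y coordinates
+// concatenated (RFC 6605, section 4), giving 32+32 = 64 bytes for P-256
+// and 48+48 = 96 bytes for P-384, which is what the length checks below
+// enforce.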
+func (k *DNSKEY) publicKeyCurve() *ecdsa.PublicKey {
+	keybuf, err := packBase64([]byte(k.PublicKey))
+	if err != nil {
+		return nil
+	}
+	pubkey := new(ecdsa.PublicKey)
+	switch k.Algorithm {
+	case ECDSAP256SHA256:
+		pubkey.Curve = elliptic.P256()
+		if len(keybuf) != 64 {
+			// wrongly encoded key
+			return nil
+		}
+	case ECDSAP384SHA384:
+		pubkey.Curve = elliptic.P384()
+		if len(keybuf) != 96 {
+			// wrongly encoded key
+			return nil
+		}
+	}
+	pubkey.X = big.NewInt(0)
+	pubkey.X.SetBytes(keybuf[:len(keybuf)/2])
+	pubkey.Y = big.NewInt(0)
+	pubkey.Y.SetBytes(keybuf[len(keybuf)/2:])
+	return pubkey
+}
+
+// publicKeyDSA returns the DSA public key from the DNSKEY record.
+func (k *DNSKEY) publicKeyDSA() *dsa.PublicKey {
+	keybuf, err := packBase64([]byte(k.PublicKey))
+	if err != nil {
+		return nil
+	}
+	if len(keybuf) < 22 { // TODO: check
+		return nil
+	}
+	t := int(keybuf[0])
+	size := 64 + t*8
+	pubkey := new(dsa.PublicKey)
+	pubkey.Parameters.Q = big.NewInt(0)
+	pubkey.Parameters.Q.SetBytes(keybuf[1:21]) // +/- 1 ?
+	pubkey.Parameters.P = big.NewInt(0)
+	pubkey.Parameters.P.SetBytes(keybuf[22 : 22+size])
+	pubkey.Parameters.G = big.NewInt(0)
+	pubkey.Parameters.G.SetBytes(keybuf[22+size+1 : 22+size*2])
+	pubkey.Y = big.NewInt(0)
+	pubkey.Y.SetBytes(keybuf[22+size*2+1 : 22+size*3])
+	return pubkey
+}
+
+// setPublicKeyRSA sets the public key (the values E and N).
+func (k *DNSKEY) setPublicKeyRSA(_E int, _N *big.Int) bool {
+	if _E == 0 || _N == nil {
+		return false
+	}
+	buf := exponentToBuf(_E)
+	buf = append(buf, _N.Bytes()...)
+	k.PublicKey = unpackBase64(buf)
+	return true
+}
+
+// setPublicKeyCurve sets the public key for elliptic curves (the values X and Y).
+func (k *DNSKEY) setPublicKeyCurve(_X, _Y *big.Int) bool {
+	if _X == nil || _Y == nil {
+		return false
+	}
+	buf := curveToBuf(_X, _Y)
+	// Check the length of the buffer, either 64 or 96 bytes
+	k.PublicKey = unpackBase64(buf)
+	return true
+}
+
+// setPublicKeyDSA sets the public key for DSA (the values Q, P, G and Y).
+func (k *DNSKEY) setPublicKeyDSA(_Q, _P, _G, _Y *big.Int) bool {
+	if _Q == nil || _P == nil || _G == nil || _Y == nil {
+		return false
+	}
+	buf := dsaToBuf(_Q, _P, _G, _Y)
+	k.PublicKey = unpackBase64(buf)
+	return true
+}
+
+// exponentToBuf returns a buffer holding the exponent length followed by the
+// exponent itself, per RFC 3110, section 2: RSA Public KEY Resource Records.
+func exponentToBuf(_E int) []byte {
+	var buf []byte
+	i := big.NewInt(int64(_E))
+	if len(i.Bytes()) < 256 {
+		buf = make([]byte, 1)
+		buf[0] = uint8(len(i.Bytes()))
+	} else {
+		buf = make([]byte, 3)
+		buf[0] = 0
+		buf[1] = uint8(len(i.Bytes()) >> 8)
+		buf[2] = uint8(len(i.Bytes()))
+	}
+	buf = append(buf, i.Bytes()...)
+	return buf
+}
+
+// curveToBuf serializes the public key values X and Y for an elliptic
+// curve. The two values are simply concatenated.
+func curveToBuf(_X, _Y *big.Int) []byte {
+	buf := _X.Bytes()
+	buf = append(buf, _Y.Bytes()...)
+	return buf
+}
+
+// dsaToBuf serializes the public key values Q, P, G and Y for DSA.
+// A length byte T is prepended and the values are concatenated.
+func dsaToBuf(_Q, _P, _G, _Y *big.Int) []byte {
+	t := byte((len(_G.Bytes()) - 64) / 8)
+	buf := []byte{t}
+	buf = append(buf, _Q.Bytes()...)
+	buf = append(buf, _P.Bytes()...)
+	buf = append(buf, _G.Bytes()...)
+	buf = append(buf, _Y.Bytes()...)
+	return buf
+}
+
+type wireSlice [][]byte
+
+func (p wireSlice) Len() int      { return len(p) }
+func (p wireSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
+func (p wireSlice) Less(i, j int) bool {
+	_, ioff, _ := UnpackDomainName(p[i], 0)
+	_, joff, _ := UnpackDomainName(p[j], 0)
+	return bytes.Compare(p[i][ioff+10:], p[j][joff+10:]) < 0
+}
+
+// Return the raw signature data.
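+// Each RR in the set is rewritten into canonical form: the owner name and
+// any domain names in the rdata are lowercased and the TTL is replaced by
+// the RRSIG's OrigTtl. The wire encodings are then sorted by RDATA (the
+// bytes after the fixed 10-byte header that follows the owner name) and
+// concatenated, per RFC 4034, sections 6.2 and 6.3.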
+func rawSignatureData(rrset []RR, s *RRSIG) (buf []byte, err error) { + wires := make(wireSlice, len(rrset)) + for i, r := range rrset { + r1 := r.copy() + r1.Header().Ttl = s.OrigTtl + labels := SplitDomainName(r1.Header().Name) + // 6.2. Canonical RR Form. (4) - wildcards + if len(labels) > int(s.Labels) { + // Wildcard + r1.Header().Name = "*." + strings.Join(labels[len(labels)-int(s.Labels):], ".") + "." + } + // RFC 4034: 6.2. Canonical RR Form. (2) - domain name to lowercase + r1.Header().Name = strings.ToLower(r1.Header().Name) + // 6.2. Canonical RR Form. (3) - domain rdata to lowercase. + // NS, MD, MF, CNAME, SOA, MB, MG, MR, PTR, + // HINFO, MINFO, MX, RP, AFSDB, RT, SIG, PX, NXT, NAPTR, KX, + // SRV, DNAME, A6 + switch x := r1.(type) { + case *NS: + x.Ns = strings.ToLower(x.Ns) + case *CNAME: + x.Target = strings.ToLower(x.Target) + case *SOA: + x.Ns = strings.ToLower(x.Ns) + x.Mbox = strings.ToLower(x.Mbox) + case *MB: + x.Mb = strings.ToLower(x.Mb) + case *MG: + x.Mg = strings.ToLower(x.Mg) + case *MR: + x.Mr = strings.ToLower(x.Mr) + case *PTR: + x.Ptr = strings.ToLower(x.Ptr) + case *MINFO: + x.Rmail = strings.ToLower(x.Rmail) + x.Email = strings.ToLower(x.Email) + case *MX: + x.Mx = strings.ToLower(x.Mx) + case *NAPTR: + x.Replacement = strings.ToLower(x.Replacement) + case *KX: + x.Exchanger = strings.ToLower(x.Exchanger) + case *SRV: + x.Target = strings.ToLower(x.Target) + case *DNAME: + x.Target = strings.ToLower(x.Target) + } + // 6.2. Canonical RR Form. (5) - origTTL + wire := make([]byte, r1.len()+1) // +1 to be safe(r) + off, err1 := PackRR(r1, wire, 0, nil, false) + if err1 != nil { + return nil, err1 + } + wire = wire[:off] + wires[i] = wire + } + sort.Sort(wires) + for _, wire := range wires { + buf = append(buf, wire...) + } + return buf, nil +} + +// Map for algorithm names. +var AlgorithmToString = map[uint8]string{ + RSAMD5: "RSAMD5", + DH: "DH", + DSA: "DSA", + RSASHA1: "RSASHA1", + DSANSEC3SHA1: "DSA-NSEC3-SHA1", + RSASHA1NSEC3SHA1: "RSASHA1-NSEC3-SHA1", + RSASHA256: "RSASHA256", + RSASHA512: "RSASHA512", + ECCGOST: "ECC-GOST", + ECDSAP256SHA256: "ECDSAP256SHA256", + ECDSAP384SHA384: "ECDSAP384SHA384", + INDIRECT: "INDIRECT", + PRIVATEDNS: "PRIVATEDNS", + PRIVATEOID: "PRIVATEOID", +} + +// Map of algorithm strings. +var StringToAlgorithm = reverseInt8(AlgorithmToString) + +// Map for hash names. +var HashToString = map[uint8]string{ + SHA1: "SHA1", + SHA256: "SHA256", + GOST94: "GOST94", + SHA384: "SHA384", + SHA512: "SHA512", +} + +// Map of hash strings. +var StringToHash = reverseInt8(HashToString) diff --git a/Godeps/_workspace/src/github.com/miekg/dns/dnssec_test.go b/Godeps/_workspace/src/github.com/miekg/dns/dnssec_test.go new file mode 100644 index 0000000000..484367cf1e --- /dev/null +++ b/Godeps/_workspace/src/github.com/miekg/dns/dnssec_test.go @@ -0,0 +1,646 @@ +package dns + +import ( + "crypto/rsa" + "reflect" + "strings" + "testing" + "time" +) + +func getKey() *DNSKEY { + key := new(DNSKEY) + key.Hdr.Name = "miek.nl." + key.Hdr.Class = ClassINET + key.Hdr.Ttl = 14400 + key.Flags = 256 + key.Protocol = 3 + key.Algorithm = RSASHA256 + key.PublicKey = "AwEAAcNEU67LJI5GEgF9QLNqLO1SMq1EdoQ6E9f85ha0k0ewQGCblyW2836GiVsm6k8Kr5ECIoMJ6fZWf3CQSQ9ycWfTyOHfmI3eQ/1Covhb2y4bAmL/07PhrL7ozWBW3wBfM335Ft9xjtXHPy7ztCbV9qZ4TVDTW/Iyg0PiwgoXVesz" + return key +} + +func getSoa() *SOA { + soa := new(SOA) + soa.Hdr = RR_Header{"miek.nl.", TypeSOA, ClassINET, 14400, 0} + soa.Ns = "open.nlnetlabs.nl." + soa.Mbox = "miekg.atoom.net." 
+ soa.Serial = 1293945905 + soa.Refresh = 14400 + soa.Retry = 3600 + soa.Expire = 604800 + soa.Minttl = 86400 + return soa +} + +func TestGenerateEC(t *testing.T) { + key := new(DNSKEY) + key.Hdr.Rrtype = TypeDNSKEY + key.Hdr.Name = "miek.nl." + key.Hdr.Class = ClassINET + key.Hdr.Ttl = 14400 + key.Flags = 256 + key.Protocol = 3 + key.Algorithm = ECDSAP256SHA256 + privkey, _ := key.Generate(256) + t.Logf("%s\n", key.String()) + t.Logf("%s\n", key.PrivateKeyString(privkey)) +} + +func TestGenerateDSA(t *testing.T) { + key := new(DNSKEY) + key.Hdr.Rrtype = TypeDNSKEY + key.Hdr.Name = "miek.nl." + key.Hdr.Class = ClassINET + key.Hdr.Ttl = 14400 + key.Flags = 256 + key.Protocol = 3 + key.Algorithm = DSA + privkey, _ := key.Generate(1024) + t.Logf("%s\n", key.String()) + t.Logf("%s\n", key.PrivateKeyString(privkey)) +} + +func TestGenerateRSA(t *testing.T) { + key := new(DNSKEY) + key.Hdr.Rrtype = TypeDNSKEY + key.Hdr.Name = "miek.nl." + key.Hdr.Class = ClassINET + key.Hdr.Ttl = 14400 + key.Flags = 256 + key.Protocol = 3 + key.Algorithm = RSASHA256 + privkey, _ := key.Generate(1024) + t.Logf("%s\n", key.String()) + t.Logf("%s\n", key.PrivateKeyString(privkey)) +} + +func TestSecure(t *testing.T) { + soa := getSoa() + + sig := new(RRSIG) + sig.Hdr = RR_Header{"miek.nl.", TypeRRSIG, ClassINET, 14400, 0} + sig.TypeCovered = TypeSOA + sig.Algorithm = RSASHA256 + sig.Labels = 2 + sig.Expiration = 1296534305 // date -u '+%s' -d"2011-02-01 04:25:05" + sig.Inception = 1293942305 // date -u '+%s' -d"2011-01-02 04:25:05" + sig.OrigTtl = 14400 + sig.KeyTag = 12051 + sig.SignerName = "miek.nl." + sig.Signature = "oMCbslaAVIp/8kVtLSms3tDABpcPRUgHLrOR48OOplkYo+8TeEGWwkSwaz/MRo2fB4FxW0qj/hTlIjUGuACSd+b1wKdH5GvzRJc2pFmxtCbm55ygAh4EUL0F6U5cKtGJGSXxxg6UFCQ0doJCmiGFa78LolaUOXImJrk6AFrGa0M=" + + key := new(DNSKEY) + key.Hdr.Name = "miek.nl." + key.Hdr.Class = ClassINET + key.Hdr.Ttl = 14400 + key.Flags = 256 + key.Protocol = 3 + key.Algorithm = RSASHA256 + key.PublicKey = "AwEAAcNEU67LJI5GEgF9QLNqLO1SMq1EdoQ6E9f85ha0k0ewQGCblyW2836GiVsm6k8Kr5ECIoMJ6fZWf3CQSQ9ycWfTyOHfmI3eQ/1Covhb2y4bAmL/07PhrL7ozWBW3wBfM335Ft9xjtXHPy7ztCbV9qZ4TVDTW/Iyg0PiwgoXVesz" + + // It should validate. Period is checked seperately, so this will keep on working + if sig.Verify(key, []RR{soa}) != nil { + t.Log("failure to validate") + t.Fail() + } +} + +func TestSignature(t *testing.T) { + sig := new(RRSIG) + sig.Hdr.Name = "miek.nl." + sig.Hdr.Class = ClassINET + sig.Hdr.Ttl = 3600 + sig.TypeCovered = TypeDNSKEY + sig.Algorithm = RSASHA1 + sig.Labels = 2 + sig.OrigTtl = 4000 + sig.Expiration = 1000 //Thu Jan 1 02:06:40 CET 1970 + sig.Inception = 800 //Thu Jan 1 01:13:20 CET 1970 + sig.KeyTag = 34641 + sig.SignerName = "miek.nl." + sig.Signature = "AwEAAaHIwpx3w4VHKi6i1LHnTaWeHCL154Jug0Rtc9ji5qwPXpBo6A5sRv7cSsPQKPIwxLpyCrbJ4mr2L0EPOdvP6z6YfljK2ZmTbogU9aSU2fiq/4wjxbdkLyoDVgtO+JsxNN4bjr4WcWhsmk1Hg93FV9ZpkWb0Tbad8DFqNDzr//kZ" + + // Should not be valid + if sig.ValidityPeriod(time.Now()) { + t.Log("should not be valid") + t.Fail() + } + + sig.Inception = 315565800 //Tue Jan 1 10:10:00 CET 1980 + sig.Expiration = 4102477800 //Fri Jan 1 10:10:00 CET 2100 + if !sig.ValidityPeriod(time.Now()) { + t.Log("should be valid") + t.Fail() + } +} + +func TestSignVerify(t *testing.T) { + // The record we want to sign + soa := new(SOA) + soa.Hdr = RR_Header{"miek.nl.", TypeSOA, ClassINET, 14400, 0} + soa.Ns = "open.nlnetlabs.nl." + soa.Mbox = "miekg.atoom.net." 
+ soa.Serial = 1293945905 + soa.Refresh = 14400 + soa.Retry = 3600 + soa.Expire = 604800 + soa.Minttl = 86400 + + soa1 := new(SOA) + soa1.Hdr = RR_Header{"*.miek.nl.", TypeSOA, ClassINET, 14400, 0} + soa1.Ns = "open.nlnetlabs.nl." + soa1.Mbox = "miekg.atoom.net." + soa1.Serial = 1293945905 + soa1.Refresh = 14400 + soa1.Retry = 3600 + soa1.Expire = 604800 + soa1.Minttl = 86400 + + srv := new(SRV) + srv.Hdr = RR_Header{"srv.miek.nl.", TypeSRV, ClassINET, 14400, 0} + srv.Port = 1000 + srv.Weight = 800 + srv.Target = "web1.miek.nl." + + // With this key + key := new(DNSKEY) + key.Hdr.Rrtype = TypeDNSKEY + key.Hdr.Name = "miek.nl." + key.Hdr.Class = ClassINET + key.Hdr.Ttl = 14400 + key.Flags = 256 + key.Protocol = 3 + key.Algorithm = RSASHA256 + privkey, _ := key.Generate(512) + + // Fill in the values of the Sig, before signing + sig := new(RRSIG) + sig.Hdr = RR_Header{"miek.nl.", TypeRRSIG, ClassINET, 14400, 0} + sig.TypeCovered = soa.Hdr.Rrtype + sig.Labels = uint8(CountLabel(soa.Hdr.Name)) // works for all 3 + sig.OrigTtl = soa.Hdr.Ttl + sig.Expiration = 1296534305 // date -u '+%s' -d"2011-02-01 04:25:05" + sig.Inception = 1293942305 // date -u '+%s' -d"2011-01-02 04:25:05" + sig.KeyTag = key.KeyTag() // Get the keyfrom the Key + sig.SignerName = key.Hdr.Name + sig.Algorithm = RSASHA256 + + for _, r := range []RR{soa, soa1, srv} { + if sig.Sign(privkey, []RR{r}) != nil { + t.Log("failure to sign the record") + t.Fail() + continue + } + if sig.Verify(key, []RR{r}) != nil { + t.Log("failure to validate") + t.Fail() + continue + } + t.Logf("validated: %s\n", r.Header().Name) + } +} + +func Test65534(t *testing.T) { + t6 := new(RFC3597) + t6.Hdr = RR_Header{"miek.nl.", 65534, ClassINET, 14400, 0} + t6.Rdata = "505D870001" + key := new(DNSKEY) + key.Hdr.Name = "miek.nl." + key.Hdr.Rrtype = TypeDNSKEY + key.Hdr.Class = ClassINET + key.Hdr.Ttl = 14400 + key.Flags = 256 + key.Protocol = 3 + key.Algorithm = RSASHA256 + privkey, _ := key.Generate(1024) + + sig := new(RRSIG) + sig.Hdr = RR_Header{"miek.nl.", TypeRRSIG, ClassINET, 14400, 0} + sig.TypeCovered = t6.Hdr.Rrtype + sig.Labels = uint8(CountLabel(t6.Hdr.Name)) + sig.OrigTtl = t6.Hdr.Ttl + sig.Expiration = 1296534305 // date -u '+%s' -d"2011-02-01 04:25:05" + sig.Inception = 1293942305 // date -u '+%s' -d"2011-01-02 04:25:05" + sig.KeyTag = key.KeyTag() + sig.SignerName = key.Hdr.Name + sig.Algorithm = RSASHA256 + if err := sig.Sign(privkey, []RR{t6}); err != nil { + t.Log(err) + t.Log("failure to sign the TYPE65534 record") + t.Fail() + } + if err := sig.Verify(key, []RR{t6}); err != nil { + t.Log(err) + t.Log("failure to validate") + t.Fail() + } else { + t.Logf("validated: %s\n", t6.Header().Name) + } +} + +func TestDnskey(t *testing.T) { + // f, _ := os.Open("t/Kmiek.nl.+010+05240.key") + pubkey, _ := ReadRR(strings.NewReader(` +miek.nl. 
IN DNSKEY 256 3 10 AwEAAZuMCu2FdugHkTrXYgl5qixvcDw1aDDlvL46/xJKbHBAHY16fNUb2b65cwko2Js/aJxUYJbZk5dwCDZxYfrfbZVtDPQuc3o8QaChVxC7/JYz2AHc9qHvqQ1j4VrH71RWINlQo6VYjzN/BGpMhOZoZOEwzp1HfsOE3lNYcoWU1smL ;{id = 5240 (zsk), size = 1024b} +`), "Kmiek.nl.+010+05240.key") + privkey, _ := pubkey.(*DNSKEY).ReadPrivateKey(strings.NewReader(` +Private-key-format: v1.2 +Algorithm: 10 (RSASHA512) +Modulus: m4wK7YV26AeROtdiCXmqLG9wPDVoMOW8vjr/EkpscEAdjXp81RvZvrlzCSjYmz9onFRgltmTl3AINnFh+t9tlW0M9C5zejxBoKFXELv8ljPYAdz2oe+pDWPhWsfvVFYg2VCjpViPM38EakyE5mhk4TDOnUd+w4TeU1hyhZTWyYs= +PublicExponent: AQAB +PrivateExponent: UfCoIQ/Z38l8vB6SSqOI/feGjHEl/fxIPX4euKf0D/32k30fHbSaNFrFOuIFmWMB3LimWVEs6u3dpbB9CQeCVg7hwU5puG7OtuiZJgDAhNeOnxvo5btp4XzPZrJSxR4WNQnwIiYWbl0aFlL1VGgHC/3By89ENZyWaZcMLW4KGWE= +Prime1: yxwC6ogAu8aVcDx2wg1V0b5M5P6jP8qkRFVMxWNTw60Vkn+ECvw6YAZZBHZPaMyRYZLzPgUlyYRd0cjupy4+fQ== +Prime2: xA1bF8M0RTIQ6+A11AoVG6GIR/aPGg5sogRkIZ7ID/sF6g9HMVU/CM2TqVEBJLRPp73cv6ZeC3bcqOCqZhz+pw== +Exponent1: xzkblyZ96bGYxTVZm2/vHMOXswod4KWIyMoOepK6B/ZPcZoIT6omLCgtypWtwHLfqyCz3MK51Nc0G2EGzg8rFQ== +Exponent2: Pu5+mCEb7T5F+kFNZhQadHUklt0JUHbi3hsEvVoHpEGSw3BGDQrtIflDde0/rbWHgDPM4WQY+hscd8UuTXrvLw== +Coefficient: UuRoNqe7YHnKmQzE6iDWKTMIWTuoqqrFAmXPmKQnC+Y+BQzOVEHUo9bXdDnoI9hzXP1gf8zENMYwYLeWpuYlFQ== +`), "Kmiek.nl.+010+05240.private") + if pubkey.(*DNSKEY).PublicKey != "AwEAAZuMCu2FdugHkTrXYgl5qixvcDw1aDDlvL46/xJKbHBAHY16fNUb2b65cwko2Js/aJxUYJbZk5dwCDZxYfrfbZVtDPQuc3o8QaChVxC7/JYz2AHc9qHvqQ1j4VrH71RWINlQo6VYjzN/BGpMhOZoZOEwzp1HfsOE3lNYcoWU1smL" { + t.Log("pubkey is not what we've read") + t.Fail() + } + // Coefficient looks fishy... + t.Logf("%s", pubkey.(*DNSKEY).PrivateKeyString(privkey)) +} + +func TestTag(t *testing.T) { + key := new(DNSKEY) + key.Hdr.Name = "miek.nl." + key.Hdr.Rrtype = TypeDNSKEY + key.Hdr.Class = ClassINET + key.Hdr.Ttl = 3600 + key.Flags = 256 + key.Protocol = 3 + key.Algorithm = RSASHA256 + key.PublicKey = "AwEAAcNEU67LJI5GEgF9QLNqLO1SMq1EdoQ6E9f85ha0k0ewQGCblyW2836GiVsm6k8Kr5ECIoMJ6fZWf3CQSQ9ycWfTyOHfmI3eQ/1Covhb2y4bAmL/07PhrL7ozWBW3wBfM335Ft9xjtXHPy7ztCbV9qZ4TVDTW/Iyg0PiwgoXVesz" + + tag := key.KeyTag() + if tag != 12051 { + t.Logf("wrong key tag: %d for key %v\n", tag, key) + t.Fail() + } +} + +func TestKeyRSA(t *testing.T) { + key := new(DNSKEY) + key.Hdr.Name = "miek.nl." + key.Hdr.Rrtype = TypeDNSKEY + key.Hdr.Class = ClassINET + key.Hdr.Ttl = 3600 + key.Flags = 256 + key.Protocol = 3 + key.Algorithm = RSASHA256 + priv, _ := key.Generate(2048) + + soa := new(SOA) + soa.Hdr = RR_Header{"miek.nl.", TypeSOA, ClassINET, 14400, 0} + soa.Ns = "open.nlnetlabs.nl." + soa.Mbox = "miekg.atoom.net." + soa.Serial = 1293945905 + soa.Refresh = 14400 + soa.Retry = 3600 + soa.Expire = 604800 + soa.Minttl = 86400 + + sig := new(RRSIG) + sig.Hdr = RR_Header{"miek.nl.", TypeRRSIG, ClassINET, 14400, 0} + sig.TypeCovered = TypeSOA + sig.Algorithm = RSASHA256 + sig.Labels = 2 + sig.Expiration = 1296534305 // date -u '+%s' -d"2011-02-01 04:25:05" + sig.Inception = 1293942305 // date -u '+%s' -d"2011-01-02 04:25:05" + sig.OrigTtl = soa.Hdr.Ttl + sig.KeyTag = key.KeyTag() + sig.SignerName = key.Hdr.Name + + if err := sig.Sign(priv, []RR{soa}); err != nil { + t.Logf("failed to sign") + t.Fail() + return + } + if err := sig.Verify(key, []RR{soa}); err != nil { + t.Logf("failed to verify") + t.Fail() + } +} + +func TestKeyToDS(t *testing.T) { + key := new(DNSKEY) + key.Hdr.Name = "miek.nl." 
+ key.Hdr.Rrtype = TypeDNSKEY + key.Hdr.Class = ClassINET + key.Hdr.Ttl = 3600 + key.Flags = 256 + key.Protocol = 3 + key.Algorithm = RSASHA256 + key.PublicKey = "AwEAAcNEU67LJI5GEgF9QLNqLO1SMq1EdoQ6E9f85ha0k0ewQGCblyW2836GiVsm6k8Kr5ECIoMJ6fZWf3CQSQ9ycWfTyOHfmI3eQ/1Covhb2y4bAmL/07PhrL7ozWBW3wBfM335Ft9xjtXHPy7ztCbV9qZ4TVDTW/Iyg0PiwgoXVesz" + + ds := key.ToDS(SHA1) + if strings.ToUpper(ds.Digest) != "B5121BDB5B8D86D0CC5FFAFBAAABE26C3E20BAC1" { + t.Logf("wrong DS digest for SHA1\n%v\n", ds) + t.Fail() + } +} + +func TestSignRSA(t *testing.T) { + pub := "miek.nl. IN DNSKEY 256 3 5 AwEAAb+8lGNCxJgLS8rYVer6EnHVuIkQDghdjdtewDzU3G5R7PbMbKVRvH2Ma7pQyYceoaqWZQirSj72euPWfPxQnMy9ucCylA+FuH9cSjIcPf4PqJfdupHk9X6EBYjxrCLY4p1/yBwgyBIRJtZtAqM3ceAH2WovEJD6rTtOuHo5AluJ" + + priv := `Private-key-format: v1.3 +Algorithm: 5 (RSASHA1) +Modulus: v7yUY0LEmAtLythV6voScdW4iRAOCF2N217APNTcblHs9sxspVG8fYxrulDJhx6hqpZlCKtKPvZ649Z8/FCczL25wLKUD4W4f1xKMhw9/g+ol926keT1foQFiPGsItjinX/IHCDIEhEm1m0Cozdx4AfZai8QkPqtO064ejkCW4k= +PublicExponent: AQAB +PrivateExponent: YPwEmwjk5HuiROKU4xzHQ6l1hG8Iiha4cKRG3P5W2b66/EN/GUh07ZSf0UiYB67o257jUDVEgwCuPJz776zfApcCB4oGV+YDyEu7Hp/rL8KcSN0la0k2r9scKwxTp4BTJT23zyBFXsV/1wRDK1A5NxsHPDMYi2SoK63Enm/1ptk= +Prime1: /wjOG+fD0ybNoSRn7nQ79udGeR1b0YhUA5mNjDx/x2fxtIXzygYk0Rhx9QFfDy6LOBvz92gbNQlzCLz3DJt5hw== +Prime2: wHZsJ8OGhkp5p3mrJFZXMDc2mbYusDVTA+t+iRPdS797Tj0pjvU2HN4vTnTj8KBQp6hmnY7dLp9Y1qserySGbw== +Exponent1: N0A7FsSRIg+IAN8YPQqlawoTtG1t1OkJ+nWrurPootScApX6iMvn8fyvw3p2k51rv84efnzpWAYiC8SUaQDNxQ== +Exponent2: SvuYRaGyvo0zemE3oS+WRm2scxR8eiA8WJGeOc+obwOKCcBgeZblXzfdHGcEC1KaOcetOwNW/vwMA46lpLzJNw== +Coefficient: 8+7ZN/JgByqv0NfULiFKTjtyegUcijRuyij7yNxYbCBneDvZGxJwKNi4YYXWx743pcAj4Oi4Oh86gcmxLs+hGw== +Created: 20110302104537 +Publish: 20110302104537 +Activate: 20110302104537` + + xk, _ := NewRR(pub) + k := xk.(*DNSKEY) + p, err := k.NewPrivateKey(priv) + if err != nil { + t.Logf("%v\n", err) + t.Fail() + } + switch priv := p.(type) { + case *rsa.PrivateKey: + if 65537 != priv.PublicKey.E { + t.Log("exponenent should be 65537") + t.Fail() + } + default: + t.Logf("we should have read an RSA key: %v", priv) + t.Fail() + } + if k.KeyTag() != 37350 { + t.Logf("%d %v\n", k.KeyTag(), k) + t.Log("keytag should be 37350") + t.Fail() + } + + soa := new(SOA) + soa.Hdr = RR_Header{"miek.nl.", TypeSOA, ClassINET, 14400, 0} + soa.Ns = "open.nlnetlabs.nl." + soa.Mbox = "miekg.atoom.net." + soa.Serial = 1293945905 + soa.Refresh = 14400 + soa.Retry = 3600 + soa.Expire = 604800 + soa.Minttl = 86400 + + sig := new(RRSIG) + sig.Hdr = RR_Header{"miek.nl.", TypeRRSIG, ClassINET, 14400, 0} + sig.Expiration = 1296534305 // date -u '+%s' -d"2011-02-01 04:25:05" + sig.Inception = 1293942305 // date -u '+%s' -d"2011-01-02 04:25:05" + sig.KeyTag = k.KeyTag() + sig.SignerName = k.Hdr.Name + sig.Algorithm = k.Algorithm + + sig.Sign(p, []RR{soa}) + if sig.Signature != "D5zsobpQcmMmYsUMLxCVEtgAdCvTu8V/IEeP4EyLBjqPJmjt96bwM9kqihsccofA5LIJ7DN91qkCORjWSTwNhzCv7bMyr2o5vBZElrlpnRzlvsFIoAZCD9xg6ZY7ZyzUJmU6IcTwG4v3xEYajcpbJJiyaw/RqR90MuRdKPiBzSo=" { + t.Log("signature is not correct") + t.Logf("%v\n", sig) + t.Fail() + } +} + +func TestSignVerifyECDSA(t *testing.T) { + pub := `example.net. 
3600 IN DNSKEY 257 3 14 ( + xKYaNhWdGOfJ+nPrL8/arkwf2EY3MDJ+SErKivBVSum1 + w/egsXvSADtNJhyem5RCOpgQ6K8X1DRSEkrbYQ+OB+v8 + /uX45NBwY8rp65F6Glur8I/mlVNgF6W/qTI37m40 )` + priv := `Private-key-format: v1.2 +Algorithm: 14 (ECDSAP384SHA384) +PrivateKey: WURgWHCcYIYUPWgeLmiPY2DJJk02vgrmTfitxgqcL4vwW7BOrbawVmVe0d9V94SR` + + eckey, err := NewRR(pub) + if err != nil { + t.Fatal(err.Error()) + } + privkey, err := eckey.(*DNSKEY).NewPrivateKey(priv) + if err != nil { + t.Fatal(err.Error()) + } + // TODO: Create seperate test for this + ds := eckey.(*DNSKEY).ToDS(SHA384) + if ds.KeyTag != 10771 { + t.Fatal("wrong keytag on DS") + } + if ds.Digest != "72d7b62976ce06438e9c0bf319013cf801f09ecc84b8d7e9495f27e305c6a9b0563a9b5f4d288405c3008a946df983d6" { + t.Fatal("wrong DS Digest") + } + a, _ := NewRR("www.example.net. 3600 IN A 192.0.2.1") + sig := new(RRSIG) + sig.Hdr = RR_Header{"example.net.", TypeRRSIG, ClassINET, 14400, 0} + sig.Expiration, _ = StringToTime("20100909102025") + sig.Inception, _ = StringToTime("20100812102025") + sig.KeyTag = eckey.(*DNSKEY).KeyTag() + sig.SignerName = eckey.(*DNSKEY).Hdr.Name + sig.Algorithm = eckey.(*DNSKEY).Algorithm + + sig.Sign(privkey, []RR{a}) + + t.Logf("%s", sig.String()) + if e := sig.Verify(eckey.(*DNSKEY), []RR{a}); e != nil { + t.Logf("failure to validate: %s", e.Error()) + t.Fail() + } +} + +func TestSignVerifyECDSA2(t *testing.T) { + srv1, err := NewRR("srv.miek.nl. IN SRV 1000 800 0 web1.miek.nl.") + if err != nil { + t.Fatalf(err.Error()) + } + srv := srv1.(*SRV) + + // With this key + key := new(DNSKEY) + key.Hdr.Rrtype = TypeDNSKEY + key.Hdr.Name = "miek.nl." + key.Hdr.Class = ClassINET + key.Hdr.Ttl = 14400 + key.Flags = 256 + key.Protocol = 3 + key.Algorithm = ECDSAP256SHA256 + privkey, err := key.Generate(256) + if err != nil { + t.Fatal("failure to generate key") + } + + // Fill in the values of the Sig, before signing + sig := new(RRSIG) + sig.Hdr = RR_Header{"miek.nl.", TypeRRSIG, ClassINET, 14400, 0} + sig.TypeCovered = srv.Hdr.Rrtype + sig.Labels = uint8(CountLabel(srv.Hdr.Name)) // works for all 3 + sig.OrigTtl = srv.Hdr.Ttl + sig.Expiration = 1296534305 // date -u '+%s' -d"2011-02-01 04:25:05" + sig.Inception = 1293942305 // date -u '+%s' -d"2011-01-02 04:25:05" + sig.KeyTag = key.KeyTag() // Get the keyfrom the Key + sig.SignerName = key.Hdr.Name + sig.Algorithm = ECDSAP256SHA256 + + if sig.Sign(privkey, []RR{srv}) != nil { + t.Fatal("failure to sign the record") + } + + err = sig.Verify(key, []RR{srv}) + if err != nil { + t.Fatal("Failure to validate:", err) + } +} + +// Here the test vectors from the relevant RFCs are checked. +// rfc6605 6.1 +func TestRFC6605P256(t *testing.T) { + exDNSKEY := `example.net. 3600 IN DNSKEY 257 3 13 ( + GojIhhXUN/u4v54ZQqGSnyhWJwaubCvTmeexv7bR6edb + krSqQpF64cYbcB7wNcP+e+MAnLr+Wi9xMWyQLc8NAA== )` + exPriv := `Private-key-format: v1.2 +Algorithm: 13 (ECDSAP256SHA256) +PrivateKey: GU6SnQ/Ou+xC5RumuIUIuJZteXT2z0O/ok1s38Et6mQ=` + rrDNSKEY, err := NewRR(exDNSKEY) + if err != nil { + t.Fatal(err.Error()) + } + priv, err := rrDNSKEY.(*DNSKEY).NewPrivateKey(exPriv) + if err != nil { + t.Fatal(err.Error()) + } + + exDS := `example.net. 3600 IN DS 55648 13 2 ( + b4c8c1fe2e7477127b27115656ad6256f424625bf5c1 + e2770ce6d6e37df61d17 )` + rrDS, err := NewRR(exDS) + if err != nil { + t.Fatal(err.Error()) + } + ourDS := rrDNSKEY.(*DNSKEY).ToDS(SHA256) + if !reflect.DeepEqual(ourDS, rrDS.(*DS)) { + t.Errorf("DS record differs:\n%v\n%v\n", ourDS, rrDS.(*DS)) + } + + exA := `www.example.net. 
3600 IN A 192.0.2.1` + exRRSIG := `www.example.net. 3600 IN RRSIG A 13 3 3600 ( + 20100909100439 20100812100439 55648 example.net. + qx6wLYqmh+l9oCKTN6qIc+bw6ya+KJ8oMz0YP107epXA + yGmt+3SNruPFKG7tZoLBLlUzGGus7ZwmwWep666VCw== )` + rrA, err := NewRR(exA) + if err != nil { + t.Fatal(err.Error()) + } + rrRRSIG, err := NewRR(exRRSIG) + if err != nil { + t.Fatal(err.Error()) + } + if err = rrRRSIG.(*RRSIG).Verify(rrDNSKEY.(*DNSKEY), []RR{rrA}); err != nil { + t.Errorf("Failure to validate the spec RRSIG: %v", err) + } + + ourRRSIG := &RRSIG{ + Hdr: RR_Header{ + Ttl: rrA.Header().Ttl, + }, + KeyTag: rrDNSKEY.(*DNSKEY).KeyTag(), + SignerName: rrDNSKEY.(*DNSKEY).Hdr.Name, + Algorithm: rrDNSKEY.(*DNSKEY).Algorithm, + } + ourRRSIG.Expiration, _ = StringToTime("20100909100439") + ourRRSIG.Inception, _ = StringToTime("20100812100439") + err = ourRRSIG.Sign(priv, []RR{rrA}) + if err != nil { + t.Fatal(err.Error()) + } + + if err = ourRRSIG.Verify(rrDNSKEY.(*DNSKEY), []RR{rrA}); err != nil { + t.Errorf("Failure to validate our RRSIG: %v", err) + } + + // Signatures are randomized + rrRRSIG.(*RRSIG).Signature = "" + ourRRSIG.Signature = "" + if !reflect.DeepEqual(ourRRSIG, rrRRSIG.(*RRSIG)) { + t.Fatalf("RRSIG record differs:\n%v\n%v\n", ourRRSIG, rrRRSIG.(*RRSIG)) + } +} + +// rfc6605 6.2 +func TestRFC6605P384(t *testing.T) { + exDNSKEY := `example.net. 3600 IN DNSKEY 257 3 14 ( + xKYaNhWdGOfJ+nPrL8/arkwf2EY3MDJ+SErKivBVSum1 + w/egsXvSADtNJhyem5RCOpgQ6K8X1DRSEkrbYQ+OB+v8 + /uX45NBwY8rp65F6Glur8I/mlVNgF6W/qTI37m40 )` + exPriv := `Private-key-format: v1.2 +Algorithm: 14 (ECDSAP384SHA384) +PrivateKey: WURgWHCcYIYUPWgeLmiPY2DJJk02vgrmTfitxgqcL4vwW7BOrbawVmVe0d9V94SR` + rrDNSKEY, err := NewRR(exDNSKEY) + if err != nil { + t.Fatal(err.Error()) + } + priv, err := rrDNSKEY.(*DNSKEY).NewPrivateKey(exPriv) + if err != nil { + t.Fatal(err.Error()) + } + + exDS := `example.net. 3600 IN DS 10771 14 4 ( + 72d7b62976ce06438e9c0bf319013cf801f09ecc84b8 + d7e9495f27e305c6a9b0563a9b5f4d288405c3008a94 + 6df983d6 )` + rrDS, err := NewRR(exDS) + if err != nil { + t.Fatal(err.Error()) + } + ourDS := rrDNSKEY.(*DNSKEY).ToDS(SHA384) + if !reflect.DeepEqual(ourDS, rrDS.(*DS)) { + t.Fatalf("DS record differs:\n%v\n%v\n", ourDS, rrDS.(*DS)) + } + + exA := `www.example.net. 3600 IN A 192.0.2.1` + exRRSIG := `www.example.net. 3600 IN RRSIG A 14 3 3600 ( + 20100909102025 20100812102025 10771 example.net. 
+ /L5hDKIvGDyI1fcARX3z65qrmPsVz73QD1Mr5CEqOiLP + 95hxQouuroGCeZOvzFaxsT8Glr74hbavRKayJNuydCuz + WTSSPdz7wnqXL5bdcJzusdnI0RSMROxxwGipWcJm )` + rrA, err := NewRR(exA) + if err != nil { + t.Fatal(err.Error()) + } + rrRRSIG, err := NewRR(exRRSIG) + if err != nil { + t.Fatal(err.Error()) + } + if err = rrRRSIG.(*RRSIG).Verify(rrDNSKEY.(*DNSKEY), []RR{rrA}); err != nil { + t.Errorf("Failure to validate the spec RRSIG: %v", err) + } + + ourRRSIG := &RRSIG{ + Hdr: RR_Header{ + Ttl: rrA.Header().Ttl, + }, + KeyTag: rrDNSKEY.(*DNSKEY).KeyTag(), + SignerName: rrDNSKEY.(*DNSKEY).Hdr.Name, + Algorithm: rrDNSKEY.(*DNSKEY).Algorithm, + } + ourRRSIG.Expiration, _ = StringToTime("20100909102025") + ourRRSIG.Inception, _ = StringToTime("20100812102025") + err = ourRRSIG.Sign(priv, []RR{rrA}) + if err != nil { + t.Fatal(err.Error()) + } + + if err = ourRRSIG.Verify(rrDNSKEY.(*DNSKEY), []RR{rrA}); err != nil { + t.Errorf("Failure to validate our RRSIG: %v", err) + } + + // Signatures are randomized + rrRRSIG.(*RRSIG).Signature = "" + ourRRSIG.Signature = "" + if !reflect.DeepEqual(ourRRSIG, rrRRSIG.(*RRSIG)) { + t.Fatalf("RRSIG record differs:\n%v\n%v\n", ourRRSIG, rrRRSIG.(*RRSIG)) + } +} diff --git a/Godeps/_workspace/src/github.com/miekg/dns/dyn_test.go b/Godeps/_workspace/src/github.com/miekg/dns/dyn_test.go new file mode 100644 index 0000000000..09986a5e4e --- /dev/null +++ b/Godeps/_workspace/src/github.com/miekg/dns/dyn_test.go @@ -0,0 +1,3 @@ +package dns + +// Find better solution diff --git a/Godeps/_workspace/src/github.com/miekg/dns/edns.go b/Godeps/_workspace/src/github.com/miekg/dns/edns.go new file mode 100644 index 0000000000..e5003e0bde --- /dev/null +++ b/Godeps/_workspace/src/github.com/miekg/dns/edns.go @@ -0,0 +1,501 @@ +// EDNS0 +// +// EDNS0 is an extension mechanism for the DNS defined in RFC 2671 and updated +// by RFC 6891. It defines an new RR type, the OPT RR, which is then completely +// abused. +// Basic use pattern for creating an (empty) OPT RR: +// +// o := new(dns.OPT) +// o.Hdr.Name = "." // MUST be the root zone, per definition. +// o.Hdr.Rrtype = dns.TypeOPT +// +// The rdata of an OPT RR consists out of a slice of EDNS0 (RFC 6891) +// interfaces. Currently only a few have been standardized: EDNS0_NSID +// (RFC 5001) and EDNS0_SUBNET (draft-vandergaast-edns-client-subnet-02). Note +// that these options may be combined in an OPT RR. +// Basic use pattern for a server to check if (and which) options are set: +// +// // o is a dns.OPT +// for _, s := range o.Option { +// switch e := s.(type) { +// case *dns.EDNS0_NSID: +// // do stuff with e.Nsid +// case *dns.EDNS0_SUBNET: +// // access e.Family, e.Address, etc. +// } +// } +package dns + +import ( + "encoding/hex" + "errors" + "net" + "strconv" +) + +// EDNS0 Option codes. +const ( + EDNS0LLQ = 0x1 // long lived queries: http://tools.ietf.org/html/draft-sekar-dns-llq-01 + EDNS0UL = 0x2 // update lease draft: http://files.dns-sd.org/draft-sekar-dns-ul.txt + EDNS0NSID = 0x3 // nsid (RFC5001) + EDNS0DAU = 0x5 // DNSSEC Algorithm Understood + EDNS0DHU = 0x6 // DS Hash Understood + EDNS0N3U = 0x7 // NSEC3 Hash Understood + EDNS0SUBNET = 0x8 // client-subnet (RFC6891) + EDNS0EXPIRE = 0x9 // EDNS0 expire + EDNS0SUBNETDRAFT = 0x50fa // Don't use! 
Use EDNS0SUBNET + _DO = 1 << 15 // dnssec ok +) + +type OPT struct { + Hdr RR_Header + Option []EDNS0 `dns:"opt"` +} + +func (rr *OPT) Header() *RR_Header { + return &rr.Hdr +} + +func (rr *OPT) String() string { + s := "\n;; OPT PSEUDOSECTION:\n; EDNS: version " + strconv.Itoa(int(rr.Version())) + "; " + if rr.Do() { + s += "flags: do; " + } else { + s += "flags: ; " + } + s += "udp: " + strconv.Itoa(int(rr.UDPSize())) + + for _, o := range rr.Option { + switch o.(type) { + case *EDNS0_NSID: + s += "\n; NSID: " + o.String() + h, e := o.pack() + var r string + if e == nil { + for _, c := range h { + r += "(" + string(c) + ")" + } + s += " " + r + } + case *EDNS0_SUBNET: + s += "\n; SUBNET: " + o.String() + if o.(*EDNS0_SUBNET).DraftOption { + s += " (draft)" + } + case *EDNS0_UL: + s += "\n; UPDATE LEASE: " + o.String() + case *EDNS0_LLQ: + s += "\n; LONG LIVED QUERIES: " + o.String() + case *EDNS0_DAU: + s += "\n; DNSSEC ALGORITHM UNDERSTOOD: " + o.String() + case *EDNS0_DHU: + s += "\n; DS HASH UNDERSTOOD: " + o.String() + case *EDNS0_N3U: + s += "\n; NSEC3 HASH UNDERSTOOD: " + o.String() + } + } + return s +} + +func (rr *OPT) len() int { + l := rr.Hdr.len() + for i := 0; i < len(rr.Option); i++ { + lo, _ := rr.Option[i].pack() + l += 2 + len(lo) + } + return l +} + +func (rr *OPT) copy() RR { + return &OPT{*rr.Hdr.copyHeader(), rr.Option} +} + +// return the old value -> delete SetVersion? + +// Version returns the EDNS version used. Only zero is defined. +func (rr *OPT) Version() uint8 { + return uint8((rr.Hdr.Ttl & 0x00FF0000) >> 16) +} + +// SetVersion sets the version of EDNS. This is usually zero. +func (rr *OPT) SetVersion(v uint8) { + rr.Hdr.Ttl = rr.Hdr.Ttl&0xFF00FFFF | (uint32(v) << 16) +} + +// ExtendedRcode returns the EDNS extended RCODE field (the upper 8 bits of the TTL). +func (rr *OPT) ExtendedRcode() uint8 { + return uint8((rr.Hdr.Ttl & 0xFF000000) >> 24) +} + +// SetExtendedRcode sets the EDNS extended RCODE field. +func (rr *OPT) SetExtendedRcode(v uint8) { + rr.Hdr.Ttl = rr.Hdr.Ttl&0x00FFFFFF | (uint32(v) << 24) +} + +// UDPSize returns the UDP buffer size. +func (rr *OPT) UDPSize() uint16 { + return rr.Hdr.Class +} + +// SetUDPSize sets the UDP buffer size. +func (rr *OPT) SetUDPSize(size uint16) { + rr.Hdr.Class = size +} + +// Do returns the value of the DO (DNSSEC OK) bit. +func (rr *OPT) Do() bool { + return rr.Hdr.Ttl&_DO == _DO +} + +// SetDo sets the DO (DNSSEC OK) bit. +func (rr *OPT) SetDo() { + rr.Hdr.Ttl |= _DO +} + +// EDNS0 defines an EDNS0 Option. An OPT RR can have multiple options appended to +// it. +type EDNS0 interface { + // Option returns the option code for the option. + Option() uint16 + // pack returns the bytes of the option data. + pack() ([]byte, error) + // unpack sets the data as found in the buffer. Is also sets + // the length of the slice as the length of the option data. + unpack([]byte) error + // String returns the string representation of the option. + String() string +} + +// The nsid EDNS0 option is used to retrieve a nameserver +// identifier. When sending a request Nsid must be set to the empty string +// The identifier is an opaque string encoded as hex. +// Basic use pattern for creating an nsid option: +// +// o := new(dns.OPT) +// o.Hdr.Name = "." 
+// o.Hdr.Rrtype = dns.TypeOPT
+// e := new(dns.EDNS0_NSID)
+// e.Code = dns.EDNS0NSID
+// e.Nsid = "AA"
+// o.Option = append(o.Option, e)
+type EDNS0_NSID struct {
+	Code uint16 // Always EDNS0NSID
+	Nsid string // This string needs to be hex encoded
+}
+
+func (e *EDNS0_NSID) pack() ([]byte, error) {
+	h, err := hex.DecodeString(e.Nsid)
+	if err != nil {
+		return nil, err
+	}
+	return h, nil
+}
+
+func (e *EDNS0_NSID) Option() uint16        { return EDNS0NSID }
+func (e *EDNS0_NSID) unpack(b []byte) error { e.Nsid = hex.EncodeToString(b); return nil }
+func (e *EDNS0_NSID) String() string        { return string(e.Nsid) }
+
+// The subnet EDNS0 option is used to give the remote nameserver
+// an idea of where the client lives. It can then give back a different
+// answer depending on the location or network topology.
+// Basic use pattern for creating a subnet option:
+//
+// o := new(dns.OPT)
+// o.Hdr.Name = "."
+// o.Hdr.Rrtype = dns.TypeOPT
+// e := new(dns.EDNS0_SUBNET)
+// e.Code = dns.EDNS0SUBNET
+// e.Family = 1 // 1 for IPv4 source address, 2 for IPv6
+// e.SourceNetmask = 32 // 32 for IPv4, 128 for IPv6
+// e.SourceScope = 0
+// e.Address = net.ParseIP("127.0.0.1").To4() // for IPv4
+// // e.Address = net.ParseIP("2001:7b8:32a::2") // for IPv6
+// o.Option = append(o.Option, e)
+type EDNS0_SUBNET struct {
+	Code          uint16 // Always EDNS0SUBNET
+	Family        uint16 // 1 for IPv4, 2 for IPv6
+	SourceNetmask uint8
+	SourceScope   uint8
+	Address       net.IP
+	DraftOption   bool // Set to true if using the old (0x50fa) option code
+}
+
+func (e *EDNS0_SUBNET) Option() uint16 {
+	if e.DraftOption {
+		return EDNS0SUBNETDRAFT
+	}
+	return EDNS0SUBNET
+}
+
+func (e *EDNS0_SUBNET) pack() ([]byte, error) {
+	b := make([]byte, 4)
+	b[0], b[1] = packUint16(e.Family)
+	b[2] = e.SourceNetmask
+	b[3] = e.SourceScope
+	switch e.Family {
+	case 1:
+		if e.SourceNetmask > net.IPv4len*8 {
+			return nil, errors.New("dns: bad netmask")
+		}
+		ip := make([]byte, net.IPv4len)
+		a := e.Address.To4().Mask(net.CIDRMask(int(e.SourceNetmask), net.IPv4len*8))
+		for i := 0; i < net.IPv4len; i++ {
+			if i+1 > len(e.Address) {
+				break
+			}
+			ip[i] = a[i]
+		}
+		needLength := e.SourceNetmask / 8
+		if e.SourceNetmask%8 > 0 {
+			needLength++
+		}
+		ip = ip[:needLength]
+		b = append(b, ip...)
+	case 2:
+		if e.SourceNetmask > net.IPv6len*8 {
+			return nil, errors.New("dns: bad netmask")
+		}
+		ip := make([]byte, net.IPv6len)
+		a := e.Address.Mask(net.CIDRMask(int(e.SourceNetmask), net.IPv6len*8))
+		for i := 0; i < net.IPv6len; i++ {
+			if i+1 > len(e.Address) {
+				break
+			}
+			ip[i] = a[i]
+		}
+		needLength := e.SourceNetmask / 8
+		if e.SourceNetmask%8 > 0 {
+			needLength++
+		}
+		ip = ip[:needLength]
+		b = append(b, ip...)
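+		// In both cases above only ceil(SourceNetmask/8) bytes of the
+		// masked address are put on the wire, so e.g. an IPv4 /24 scope
+		// transmits just the first 3 address bytes.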
+ default: + return nil, errors.New("dns: bad address family") + } + return b, nil +} + +func (e *EDNS0_SUBNET) unpack(b []byte) error { + lb := len(b) + if lb < 4 { + return ErrBuf + } + e.Family, _ = unpackUint16(b, 0) + e.SourceNetmask = b[2] + e.SourceScope = b[3] + switch e.Family { + case 1: + addr := make([]byte, 4) + for i := 0; i < int(e.SourceNetmask/8); i++ { + if 4+i > len(b) { + return ErrBuf + } + addr[i] = b[4+i] + } + e.Address = net.IPv4(addr[0], addr[1], addr[2], addr[3]) + case 2: + addr := make([]byte, 16) + for i := 0; i < int(e.SourceNetmask/8); i++ { + if 4+i > len(b) { + return ErrBuf + } + addr[i] = b[4+i] + } + e.Address = net.IP{addr[0], addr[1], addr[2], addr[3], addr[4], + addr[5], addr[6], addr[7], addr[8], addr[9], addr[10], + addr[11], addr[12], addr[13], addr[14], addr[15]} + } + return nil +} + +func (e *EDNS0_SUBNET) String() (s string) { + if e.Address == nil { + s = "" + } else if e.Address.To4() != nil { + s = e.Address.String() + } else { + s = "[" + e.Address.String() + "]" + } + s += "/" + strconv.Itoa(int(e.SourceNetmask)) + "/" + strconv.Itoa(int(e.SourceScope)) + return +} + +// The UL (Update Lease) EDNS0 (draft RFC) option is used to tell the server to set +// an expiration on an update RR. This is helpful for clients that cannot clean +// up after themselves. This is a draft RFC and more information can be found at +// http://files.dns-sd.org/draft-sekar-dns-ul.txt +// +// o := new(dns.OPT) +// o.Hdr.Name = "." +// o.Hdr.Rrtype = dns.TypeOPT +// e := new(dns.EDNS0_UL) +// e.Code = dns.EDNS0UL +// e.Lease = 120 // in seconds +// o.Option = append(o.Option, e) +type EDNS0_UL struct { + Code uint16 // Always EDNS0UL + Lease uint32 +} + +func (e *EDNS0_UL) Option() uint16 { return EDNS0UL } +func (e *EDNS0_UL) String() string { return strconv.FormatUint(uint64(e.Lease), 10) } + +// Copied: http://golang.org/src/pkg/net/dnsmsg.go +func (e *EDNS0_UL) pack() ([]byte, error) { + b := make([]byte, 4) + b[0] = byte(e.Lease >> 24) + b[1] = byte(e.Lease >> 16) + b[2] = byte(e.Lease >> 8) + b[3] = byte(e.Lease) + return b, nil +} + +func (e *EDNS0_UL) unpack(b []byte) error { + if len(b) < 4 { + return ErrBuf + } + e.Lease = uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3]) + return nil +} + +// Long Lived Queries: http://tools.ietf.org/html/draft-sekar-dns-llq-01 +// Implemented for completeness, as the EDNS0 type code is assigned. 
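+// Basic use pattern (a sketch, with illustrative values):
+//
+// o := new(dns.OPT)
+// o.Hdr.Name = "."
+// o.Hdr.Rrtype = dns.TypeOPT
+// e := new(dns.EDNS0_LLQ)
+// e.Code = dns.EDNS0LLQ
+// e.Version = 1
+// e.Opcode = 1 // LLQ-SETUP in the draft
+// o.Option = append(o.Option, e)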
+type EDNS0_LLQ struct { + Code uint16 // Always EDNS0LLQ + Version uint16 + Opcode uint16 + Error uint16 + Id uint64 + LeaseLife uint32 +} + +func (e *EDNS0_LLQ) Option() uint16 { return EDNS0LLQ } + +func (e *EDNS0_LLQ) pack() ([]byte, error) { + b := make([]byte, 18) + b[0], b[1] = packUint16(e.Version) + b[2], b[3] = packUint16(e.Opcode) + b[4], b[5] = packUint16(e.Error) + b[6] = byte(e.Id >> 56) + b[7] = byte(e.Id >> 48) + b[8] = byte(e.Id >> 40) + b[9] = byte(e.Id >> 32) + b[10] = byte(e.Id >> 24) + b[11] = byte(e.Id >> 16) + b[12] = byte(e.Id >> 8) + b[13] = byte(e.Id) + b[14] = byte(e.LeaseLife >> 24) + b[15] = byte(e.LeaseLife >> 16) + b[16] = byte(e.LeaseLife >> 8) + b[17] = byte(e.LeaseLife) + return b, nil +} + +func (e *EDNS0_LLQ) unpack(b []byte) error { + if len(b) < 18 { + return ErrBuf + } + e.Version, _ = unpackUint16(b, 0) + e.Opcode, _ = unpackUint16(b, 2) + e.Error, _ = unpackUint16(b, 4) + e.Id = uint64(b[6])<<56 | uint64(b[6+1])<<48 | uint64(b[6+2])<<40 | + uint64(b[6+3])<<32 | uint64(b[6+4])<<24 | uint64(b[6+5])<<16 | uint64(b[6+6])<<8 | uint64(b[6+7]) + e.LeaseLife = uint32(b[14])<<24 | uint32(b[14+1])<<16 | uint32(b[14+2])<<8 | uint32(b[14+3]) + return nil +} + +func (e *EDNS0_LLQ) String() string { + s := strconv.FormatUint(uint64(e.Version), 10) + " " + strconv.FormatUint(uint64(e.Opcode), 10) + + " " + strconv.FormatUint(uint64(e.Error), 10) + " " + strconv.FormatUint(uint64(e.Id), 10) + + " " + strconv.FormatUint(uint64(e.LeaseLife), 10) + return s +} + +type EDNS0_DAU struct { + Code uint16 // Always EDNS0DAU + AlgCode []uint8 +} + +func (e *EDNS0_DAU) Option() uint16 { return EDNS0DAU } +func (e *EDNS0_DAU) pack() ([]byte, error) { return e.AlgCode, nil } +func (e *EDNS0_DAU) unpack(b []byte) error { e.AlgCode = b; return nil } + +func (e *EDNS0_DAU) String() string { + s := "" + for i := 0; i < len(e.AlgCode); i++ { + if a, ok := AlgorithmToString[e.AlgCode[i]]; ok { + s += " " + a + } else { + s += " " + strconv.Itoa(int(e.AlgCode[i])) + } + } + return s +} + +type EDNS0_DHU struct { + Code uint16 // Always EDNS0DHU + AlgCode []uint8 +} + +func (e *EDNS0_DHU) Option() uint16 { return EDNS0DHU } +func (e *EDNS0_DHU) pack() ([]byte, error) { return e.AlgCode, nil } +func (e *EDNS0_DHU) unpack(b []byte) error { e.AlgCode = b; return nil } + +func (e *EDNS0_DHU) String() string { + s := "" + for i := 0; i < len(e.AlgCode); i++ { + if a, ok := HashToString[e.AlgCode[i]]; ok { + s += " " + a + } else { + s += " " + strconv.Itoa(int(e.AlgCode[i])) + } + } + return s +} + +type EDNS0_N3U struct { + Code uint16 // Always EDNS0N3U + AlgCode []uint8 +} + +func (e *EDNS0_N3U) Option() uint16 { return EDNS0N3U } +func (e *EDNS0_N3U) pack() ([]byte, error) { return e.AlgCode, nil } +func (e *EDNS0_N3U) unpack(b []byte) error { e.AlgCode = b; return nil } + +func (e *EDNS0_N3U) String() string { + // Re-use the hash map + s := "" + for i := 0; i < len(e.AlgCode); i++ { + if a, ok := HashToString[e.AlgCode[i]]; ok { + s += " " + a + } else { + s += " " + strconv.Itoa(int(e.AlgCode[i])) + } + } + return s +} + +type EDNS0_EXPIRE struct { + Code uint16 // Always EDNS0EXPIRE + Expire uint32 +} + +func (e *EDNS0_EXPIRE) Option() uint16 { return EDNS0EXPIRE } +func (e *EDNS0_EXPIRE) String() string { return strconv.FormatUint(uint64(e.Expire), 10) } + +func (e *EDNS0_EXPIRE) pack() ([]byte, error) { + b := make([]byte, 4) + b[0] = byte(e.Expire >> 24) + b[1] = byte(e.Expire >> 16) + b[2] = byte(e.Expire >> 8) + b[3] = byte(e.Expire) + return b, nil +} + +func (e 
*EDNS0_EXPIRE) unpack(b []byte) error { + if len(b) < 4 { + return ErrBuf + } + e.Expire = uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3]) + return nil +} diff --git a/Godeps/_workspace/src/github.com/miekg/dns/edns_test.go b/Godeps/_workspace/src/github.com/miekg/dns/edns_test.go new file mode 100644 index 0000000000..8ee82ab426 --- /dev/null +++ b/Godeps/_workspace/src/github.com/miekg/dns/edns_test.go @@ -0,0 +1,48 @@ +package dns + +import "testing" + +func TestOPTTtl(t *testing.T) { + e := &OPT{} + e.Hdr.Name = "." + e.Hdr.Rrtype = TypeOPT + + if e.Do() { + t.Fail() + } + + e.SetDo() + if !e.Do() { + t.Fail() + } + + oldTtl := e.Hdr.Ttl + + if e.Version() != 0 { + t.Fail() + } + + e.SetVersion(42) + if e.Version() != 42 { + t.Fail() + } + + e.SetVersion(0) + if e.Hdr.Ttl != oldTtl { + t.Fail() + } + + if e.ExtendedRcode() != 0 { + t.Fail() + } + + e.SetExtendedRcode(42) + if e.ExtendedRcode() != 42 { + t.Fail() + } + + e.SetExtendedRcode(0) + if e.Hdr.Ttl != oldTtl { + t.Fail() + } +} diff --git a/Godeps/_workspace/src/github.com/miekg/dns/example_test.go b/Godeps/_workspace/src/github.com/miekg/dns/example_test.go new file mode 100644 index 0000000000..216b8b39c3 --- /dev/null +++ b/Godeps/_workspace/src/github.com/miekg/dns/example_test.go @@ -0,0 +1,147 @@ +package dns_test + +import ( + "errors" + "fmt" + "github.com/miekg/dns" + "log" + "net" +) + +// Retrieve the MX records for miek.nl. +func ExampleMX() { + config, _ := dns.ClientConfigFromFile("/etc/resolv.conf") + c := new(dns.Client) + m := new(dns.Msg) + m.SetQuestion("miek.nl.", dns.TypeMX) + m.RecursionDesired = true + r, _, err := c.Exchange(m, config.Servers[0]+":"+config.Port) + if err != nil { + return + } + if r.Rcode != dns.RcodeSuccess { + return + } + for _, a := range r.Answer { + if mx, ok := a.(*dns.MX); ok { + fmt.Printf("%s\n", mx.String()) + } + } +} + +// Retrieve the DNSKEY records of a zone and convert them +// to DS records for SHA1, SHA256 and SHA384. +func ExampleDS(zone string) { + config, _ := dns.ClientConfigFromFile("/etc/resolv.conf") + c := new(dns.Client) + m := new(dns.Msg) + if zone == "" { + zone = "miek.nl" + } + m.SetQuestion(dns.Fqdn(zone), dns.TypeDNSKEY) + m.SetEdns0(4096, true) + r, _, err := c.Exchange(m, config.Servers[0]+":"+config.Port) + if err != nil { + return + } + if r.Rcode != dns.RcodeSuccess { + return + } + for _, k := range r.Answer { + if key, ok := k.(*dns.DNSKEY); ok { + for _, alg := range []int{dns.SHA1, dns.SHA256, dns.SHA384} { + fmt.Printf("%s; %d\n", key.ToDS(alg).String(), key.Flags) + } + } + } +} + +const TypeAPAIR = 0x0F99 + +type APAIR struct { + addr [2]net.IP +} + +func NewAPAIR() dns.PrivateRdata { return new(APAIR) } + +func (rd *APAIR) String() string { return rd.addr[0].String() + " " + rd.addr[1].String() } +func (rd *APAIR) Parse(txt []string) error { + if len(txt) != 2 { + return errors.New("two addresses required for APAIR") + } + for i, s := range txt { + ip := net.ParseIP(s) + if ip == nil { + return errors.New("invalid IP in APAIR text representation") + } + rd.addr[i] = ip + } + return nil +} + +func (rd *APAIR) Pack(buf []byte) (int, error) { + b := append([]byte(rd.addr[0]), []byte(rd.addr[1])...) 
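+	// Note: net.ParseIP stores IPv4 addresses in 16-byte form, so b may be
+	// longer than the 8 bytes Len() reports unless the addresses are first
+	// reduced with To4().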
+	n := copy(buf, b)
+	if n != len(b) {
+		return n, dns.ErrBuf
+	}
+	return n, nil
+}
+
+func (rd *APAIR) Unpack(buf []byte) (int, error) {
+	ln := net.IPv4len * 2
+	if len(buf) != ln {
+		return 0, errors.New("invalid length of APAIR rdata")
+	}
+	cp := make([]byte, ln)
+	copy(cp, buf) // clone bytes to use them in IPs
+
+	rd.addr[0] = net.IP(cp[:4])
+	rd.addr[1] = net.IP(cp[4:])
+
+	return len(buf), nil
+}
+
+func (rd *APAIR) Copy(dest dns.PrivateRdata) error {
+	cp := make([]byte, rd.Len())
+	_, err := rd.Pack(cp)
+	if err != nil {
+		return err
+	}
+
+	d := dest.(*APAIR)
+	d.addr[0] = net.IP(cp[:4])
+	d.addr[1] = net.IP(cp[4:])
+	return nil
+}
+
+func (rd *APAIR) Len() int {
+	return net.IPv4len * 2
+}
+
+func ExamplePrivateHandle() {
+	dns.PrivateHandle("APAIR", TypeAPAIR, NewAPAIR)
+	defer dns.PrivateHandleRemove(TypeAPAIR)
+
+	rr, err := dns.NewRR("miek.nl. APAIR (1.2.3.4 1.2.3.5)")
+	if err != nil {
+		log.Fatal("could not parse APAIR record: ", err)
+	}
+	fmt.Println(rr)
+	// Output: miek.nl. 3600 IN APAIR 1.2.3.4 1.2.3.5
+
+	m := new(dns.Msg)
+	m.Id = 12345
+	m.SetQuestion("miek.nl.", TypeAPAIR)
+	m.Answer = append(m.Answer, rr)
+
+	fmt.Println(m)
+	// ;; opcode: QUERY, status: NOERROR, id: 12345
+	// ;; flags: rd; QUERY: 1, ANSWER: 1, AUTHORITY: 0, ADDITIONAL: 0
+	//
+	// ;; QUESTION SECTION:
+	// ;miek.nl. IN APAIR
+	//
+	// ;; ANSWER SECTION:
+	// miek.nl. 3600 IN APAIR 1.2.3.4 1.2.3.5
+}
diff --git a/Godeps/_workspace/src/github.com/miekg/dns/idn/example_test.go b/Godeps/_workspace/src/github.com/miekg/dns/idn/example_test.go
new file mode 100644
index 0000000000..8833cd91de
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/miekg/dns/idn/example_test.go
@@ -0,0 +1,18 @@
+package idn_test
+
+import (
+	"fmt"
+	"github.com/miekg/dns/idn"
+)
+
+func ExampleToPunycode() {
+	name := "インターネット.テスト"
+	fmt.Printf("%s -> %s", name, idn.ToPunycode(name))
+	// Output: インターネット.テスト -> xn--eckucmux0ukc.xn--zckzah
+}
+
+func ExampleFromPunycode() {
+	name := "xn--mgbaja8a1hpac.xn--mgbachtv"
+	fmt.Printf("%s -> %s", name, idn.FromPunycode(name))
+	// Output: xn--mgbaja8a1hpac.xn--mgbachtv -> الانترنت.اختبار
+}
diff --git a/Godeps/_workspace/src/github.com/miekg/dns/idn/punycode.go b/Godeps/_workspace/src/github.com/miekg/dns/idn/punycode.go
new file mode 100644
index 0000000000..faab4027b9
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/miekg/dns/idn/punycode.go
@@ -0,0 +1,268 @@
+// Package idn implements encoding from and to punycode as specified by RFC 3492.
+package idn
+
+import (
+	"bytes"
+	"github.com/miekg/dns"
+	"strings"
+	"unicode"
+)
+
+// Implementation idea from the RFC itself and from IDNA::Punycode created by
+// Tatsuhiko Miyagawa and released under the Perl Artistic
+// License in 2002.
+
+const (
+	_MIN  rune = 1
+	_MAX  rune = 26
+	_SKEW rune = 38
+	_BASE rune = 36
+	_BIAS rune = 72
+	_N    rune = 128
+	_DAMP rune = 700
+
+	_DELIMITER = '-'
+	_PREFIX    = "xn--"
+)
+
+// ToPunycode converts Unicode domain names to DNS-appropriate punycode names.
+// This function returns an incorrect result for non-canonical Unicode strings.
+func ToPunycode(s string) string {
+	tokens := dns.SplitDomainName(s)
+	switch {
+	case s == "":
+		return ""
+	case tokens == nil: // s == .
+		return "."
+	case s[len(s)-1] == '.':
+		tokens = append(tokens, "")
+	}
+
+	for i := range tokens {
+		tokens[i] = string(encode([]byte(tokens[i])))
+	}
+	return strings.Join(tokens, ".")
+}
+
+// FromPunycode returns the Unicode domain name from the provided punycode string.
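+// For example, "xn--zckzah" decodes to "テスト".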
+func FromPunycode(s string) string {
+	tokens := dns.SplitDomainName(s)
+	switch {
+	case s == "":
+		return ""
+	case tokens == nil: // s == .
+		return "."
+	case s[len(s)-1] == '.':
+		tokens = append(tokens, "")
+	}
+	for i := range tokens {
+		tokens[i] = string(decode([]byte(tokens[i])))
+	}
+	return strings.Join(tokens, ".")
+}
+
+// digitval converts a single byte into the value it encodes, which is used to
+// calculate the decoded Unicode character.
+const errdigit = 0xffff
+
+func digitval(code rune) rune {
+	switch {
+	case code >= 'A' && code <= 'Z':
+		return code - 'A'
+	case code >= 'a' && code <= 'z':
+		return code - 'a'
+	case code >= '0' && code <= '9':
+		return code - '0' + 26
+	}
+	return errdigit
+}
+
+// lettercode finds the base-36 byte (a-z0-9) for a calculated number.
+func lettercode(digit rune) rune {
+	switch {
+	case digit >= 0 && digit <= 25:
+		return digit + 'a'
+	case digit >= 26 && digit <= 35:
+		return digit - 26 + '0'
+	}
+	panic("dns: not reached")
+}
+
+// adapt calculates the next bias to be used for the next iteration's delta.
+func adapt(delta rune, numpoints int, firsttime bool) rune {
+	if firsttime {
+		delta /= _DAMP
+	} else {
+		delta /= 2
+	}
+
+	var k rune
+	for delta = delta + delta/rune(numpoints); delta > (_BASE-_MIN)*_MAX/2; k += _BASE {
+		delta /= _BASE - _MIN
+	}
+
+	return k + ((_BASE-_MIN+1)*delta)/(delta+_SKEW)
+}
+
+// next finds the minimal rune (the one with the lowest codepoint value) that is
+// equal to or above the boundary.
+func next(b []rune, boundary rune) rune {
+	if len(b) == 0 {
+		panic("dns: invalid set of runes to determine next one")
+	}
+	m := b[0]
+	for _, x := range b[1:] {
+		if x >= boundary && (m < boundary || x < m) {
+			m = x
+		}
+	}
+	return m
+}
+
+// preprune converts a Unicode rune to lower case. At this time it does not
+// support everything described in the RFCs.
+func preprune(r rune) rune {
+	if unicode.IsUpper(r) {
+		r = unicode.ToLower(r)
+	}
+	return r
+}
+
+// tfunc is a function that helps calculate each character's weight.
+func tfunc(k, bias rune) rune {
+	switch {
+	case k <= bias:
+		return _MIN
+	case k >= bias+_MAX:
+		return _MAX
+	}
+	return k - bias
+}
+
+// encode transforms Unicode input bytes (that represent a DNS label) into a punycode bytestream.
+func encode(input []byte) []byte {
+	n, bias := _N, _BIAS
+
+	b := bytes.Runes(input)
+	for i := range b {
+		b[i] = preprune(b[i])
+	}
+
+	basic := make([]byte, 0, len(b))
+	for _, ltr := range b {
+		if ltr <= 0x7f {
+			basic = append(basic, byte(ltr))
+		}
+	}
+	basiclen := len(basic)
+	fulllen := len(b)
+	if basiclen == fulllen {
+		return basic
+	}
+
+	var out bytes.Buffer
+
+	out.WriteString(_PREFIX)
+	if basiclen > 0 {
+		out.Write(basic)
+		out.WriteByte(_DELIMITER)
+	}
+
+	var (
+		ltr, nextltr rune
+		delta, q     rune // delta calculation (see rfc)
+		t, k, cp     rune // weight and codepoint calculation
+	)
+
+	s := &bytes.Buffer{}
+	for h := basiclen; h < fulllen; n, delta = n+1, delta+1 {
+		nextltr = next(b, n)
+		s.Truncate(0)
+		s.WriteRune(nextltr)
+		delta, n = delta+(nextltr-n)*rune(h+1), nextltr
+
+		for _, ltr = range b {
+			if ltr < n {
+				delta++
+			}
+			if ltr == n {
+				q = delta
+				for k = _BASE; ; k += _BASE {
+					t = tfunc(k, bias)
+					if q < t {
+						break
+					}
+					cp = t + ((q - t) % (_BASE - t))
+					out.WriteRune(lettercode(cp))
+					q = (q - t) / (_BASE - t)
+				}
+
+				out.WriteRune(lettercode(q))
+
+				bias = adapt(delta, h+1, h == basiclen)
+				h, delta = h+1, 0
+			}
+		}
+	}
+	return out.Bytes()
+}
+
+// decode transforms punycode input bytes (that represent a DNS label) into a Unicode bytestream.
+func decode(b []byte) []byte {
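+	// Decoding follows RFC 3492: the code points before the last delimiter
+	// are copied verbatim, then each variable-length integer that follows is
+	// turned into an (index, code point) insertion into the output.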
+ src := b // b would move and we need to keep it + + n, bias := _N, _BIAS + if !bytes.HasPrefix(b, []byte(_PREFIX)) { + return b + } + out := make([]rune, 0, len(b)) + b = b[len(_PREFIX):] + for pos, x := range b { + if x == _DELIMITER { + out = append(out, bytes.Runes(b[:pos])...) + b = b[pos+1:] // trim source string + break + } + } + if len(b) == 0 { + return src + } + var ( + i, oldi, w rune + ch byte + t, digit rune + ln int + ) + + for i = 0; len(b) > 0; i++ { + oldi, w = i, 1 + for k := _BASE; len(b) > 0; k += _BASE { + ch, b = b[0], b[1:] + digit = digitval(rune(ch)) + if digit == errdigit { + return src + } + i += digit * w + + t = tfunc(k, bias) + if digit < t { + break + } + + w *= _BASE - t + } + ln = len(out) + 1 + bias = adapt(i-oldi, ln, oldi == 0) + n += i / rune(ln) + i = i % rune(ln) + // insert + out = append(out, 0) + copy(out[i+1:], out[i:]) + out[i] = n + } + + var ret bytes.Buffer + for _, r := range out { + ret.WriteRune(r) + } + return ret.Bytes() +} diff --git a/Godeps/_workspace/src/github.com/miekg/dns/idn/punycode_test.go b/Godeps/_workspace/src/github.com/miekg/dns/idn/punycode_test.go new file mode 100644 index 0000000000..3202450aa5 --- /dev/null +++ b/Godeps/_workspace/src/github.com/miekg/dns/idn/punycode_test.go @@ -0,0 +1,94 @@ +package idn + +import ( + "strings" + "testing" +) + +var testcases = [][2]string{ + {"", ""}, + {"a", "a"}, + {"A-B", "a-b"}, + {"AbC", "abc"}, + {"я", "xn--41a"}, + {"zя", "xn--z-0ub"}, + {"ЯZ", "xn--z-zub"}, + {"إختبار", "xn--kgbechtv"}, + {"آزمایشی", "xn--hgbk6aj7f53bba"}, + {"测试", "xn--0zwm56d"}, + {"測試", "xn--g6w251d"}, + {"Испытание", "xn--80akhbyknj4f"}, + {"परीक्षा", "xn--11b5bs3a9aj6g"}, + {"δοκιμή", "xn--jxalpdlp"}, + {"테스트", "xn--9t4b11yi5a"}, + {"טעסט", "xn--deba0ad"}, + {"テスト", "xn--zckzah"}, + {"பரிட்சை", "xn--hlcj6aya9esc7a"}, +} + +func TestEncodeDecodePunycode(t *testing.T) { + for _, tst := range testcases { + enc := encode([]byte(tst[0])) + if string(enc) != tst[1] { + t.Errorf("%s encodeded as %s but should be %s", tst[0], enc, tst[1]) + } + dec := decode([]byte(tst[1])) + if string(dec) != strings.ToLower(tst[0]) { + t.Errorf("%s decoded as %s but should be %s", tst[1], dec, strings.ToLower(tst[0])) + } + } +} + +func TestToFromPunycode(t *testing.T) { + for _, tst := range testcases { + // assert unicode.com == punycode.com + full := ToPunycode(tst[0] + ".com") + if full != tst[1]+".com" { + t.Errorf("invalid result from string conversion to punycode, %s and should be %s.com", full, tst[1]) + } + // assert punycode.punycode == unicode.unicode + decoded := FromPunycode(tst[1] + "." + tst[1]) + if decoded != strings.ToLower(tst[0]+"."+tst[0]) { + t.Errorf("invalid result from string conversion to punycode, %s and should be %s.%s", decoded, tst[0], tst[0]) + } + } +} + +func TestEncodeDecodeFinalPeriod(t *testing.T) { + for _, tst := range testcases { + // assert unicode.com. == punycode.com. + full := ToPunycode(tst[0] + ".") + if full != tst[1]+"." { + t.Errorf("invalid result from string conversion to punycode when period added at the end, %#v and should be %#v", full, tst[1]+".") + } + // assert punycode.com. == unicode.com. 
+ decoded := FromPunycode(tst[1] + ".") + if decoded != strings.ToLower(tst[0]+".") { + t.Errorf("invalid result from string conversion to punycode when period added, %#v and should be %#v", decoded, tst[0]+".") + } + full = ToPunycode(tst[0]) + if full != tst[1] { + t.Errorf("invalid result from string conversion to punycode when no period added at the end, %#v and should be %#v", full, tst[1]+".") + } + // assert punycode.com. == unicode.com. + decoded = FromPunycode(tst[1]) + if decoded != strings.ToLower(tst[0]) { + t.Errorf("invalid result from string conversion to punycode when no period added, %#v and should be %#v", decoded, tst[0]+".") + } + } +} + +var invalid = []string{ + "xn--*", + "xn--", + "xn---", +} + +func TestInvalidPunycode(t *testing.T) { + for _, d := range invalid { + s := FromPunycode(d) + if s != d { + t.Errorf("Changed invalid name %s to %#v", d, s) + } + } +} diff --git a/Godeps/_workspace/src/github.com/miekg/dns/keygen.go b/Godeps/_workspace/src/github.com/miekg/dns/keygen.go new file mode 100644 index 0000000000..bc1a079779 --- /dev/null +++ b/Godeps/_workspace/src/github.com/miekg/dns/keygen.go @@ -0,0 +1,149 @@ +package dns + +import ( + "crypto/dsa" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "math/big" + "strconv" +) + +const _FORMAT = "Private-key-format: v1.3\n" + +// Empty interface that is used as a wrapper around all possible +// private key implementations from the crypto package. +type PrivateKey interface{} + +// Generate generates a DNSKEY of the given bit size. +// The public part is put inside the DNSKEY record. +// The Algorithm in the key must be set as this will define +// what kind of DNSKEY will be generated. +// The ECDSA algorithms imply a fixed keysize, in that case +// bits should be set to the size of the algorithm. +func (r *DNSKEY) Generate(bits int) (PrivateKey, error) { + switch r.Algorithm { + case DSA, DSANSEC3SHA1: + if bits != 1024 { + return nil, ErrKeySize + } + case RSAMD5, RSASHA1, RSASHA256, RSASHA1NSEC3SHA1: + if bits < 512 || bits > 4096 { + return nil, ErrKeySize + } + case RSASHA512: + if bits < 1024 || bits > 4096 { + return nil, ErrKeySize + } + case ECDSAP256SHA256: + if bits != 256 { + return nil, ErrKeySize + } + case ECDSAP384SHA384: + if bits != 384 { + return nil, ErrKeySize + } + } + + switch r.Algorithm { + case DSA, DSANSEC3SHA1: + params := new(dsa.Parameters) + if err := dsa.GenerateParameters(params, rand.Reader, dsa.L1024N160); err != nil { + return nil, err + } + priv := new(dsa.PrivateKey) + priv.PublicKey.Parameters = *params + err := dsa.GenerateKey(priv, rand.Reader) + if err != nil { + return nil, err + } + r.setPublicKeyDSA(params.Q, params.P, params.G, priv.PublicKey.Y) + return priv, nil + case RSAMD5, RSASHA1, RSASHA256, RSASHA512, RSASHA1NSEC3SHA1: + priv, err := rsa.GenerateKey(rand.Reader, bits) + if err != nil { + return nil, err + } + r.setPublicKeyRSA(priv.PublicKey.E, priv.PublicKey.N) + return priv, nil + case ECDSAP256SHA256, ECDSAP384SHA384: + var c elliptic.Curve + switch r.Algorithm { + case ECDSAP256SHA256: + c = elliptic.P256() + case ECDSAP384SHA384: + c = elliptic.P384() + } + priv, err := ecdsa.GenerateKey(c, rand.Reader) + if err != nil { + return nil, err + } + r.setPublicKeyCurve(priv.PublicKey.X, priv.PublicKey.Y) + return priv, nil + default: + return nil, ErrAlg + } + return nil, nil // Dummy return +} + +// PrivateKeyString converts a PrivateKey to a string. 
This
+// string has the same format as the private-key-file of BIND9 (Private-key-format: v1.3).
+// It needs some info from the key (hashing, keytag), so it's a method of the DNSKEY.
+func (r *DNSKEY) PrivateKeyString(p PrivateKey) (s string) {
+	switch t := p.(type) {
+	case *rsa.PrivateKey:
+		algorithm := strconv.Itoa(int(r.Algorithm)) + " (" + AlgorithmToString[r.Algorithm] + ")"
+		modulus := unpackBase64(t.PublicKey.N.Bytes())
+		e := big.NewInt(int64(t.PublicKey.E))
+		publicExponent := unpackBase64(e.Bytes())
+		privateExponent := unpackBase64(t.D.Bytes())
+		prime1 := unpackBase64(t.Primes[0].Bytes())
+		prime2 := unpackBase64(t.Primes[1].Bytes())
+		// Calculate Exponent1/2 and Coefficient as per: http://en.wikipedia.org/wiki/RSA#Using_the_Chinese_remainder_algorithm
+		// and from: http://code.google.com/p/go/issues/detail?id=987
+		one := big.NewInt(1)
+		minusone := big.NewInt(-1)
+		p_1 := big.NewInt(0).Sub(t.Primes[0], one)
+		q_1 := big.NewInt(0).Sub(t.Primes[1], one)
+		exp1 := big.NewInt(0).Mod(t.D, p_1)
+		exp2 := big.NewInt(0).Mod(t.D, q_1)
+		coeff := big.NewInt(0).Exp(t.Primes[1], minusone, t.Primes[0])
+
+		exponent1 := unpackBase64(exp1.Bytes())
+		exponent2 := unpackBase64(exp2.Bytes())
+		coefficient := unpackBase64(coeff.Bytes())
+
+		s = _FORMAT +
+			"Algorithm: " + algorithm + "\n" +
+			"Modulus: " + modulus + "\n" +
+			"PublicExponent: " + publicExponent + "\n" +
+			"PrivateExponent: " + privateExponent + "\n" +
+			"Prime1: " + prime1 + "\n" +
+			"Prime2: " + prime2 + "\n" +
+			"Exponent1: " + exponent1 + "\n" +
+			"Exponent2: " + exponent2 + "\n" +
+			"Coefficient: " + coefficient + "\n"
+	case *ecdsa.PrivateKey:
+		algorithm := strconv.Itoa(int(r.Algorithm)) + " (" + AlgorithmToString[r.Algorithm] + ")"
+		private := unpackBase64(t.D.Bytes())
+		s = _FORMAT +
+			"Algorithm: " + algorithm + "\n" +
+			"PrivateKey: " + private + "\n"
+	case *dsa.PrivateKey:
+		algorithm := strconv.Itoa(int(r.Algorithm)) + " (" + AlgorithmToString[r.Algorithm] + ")"
+		prime := unpackBase64(t.PublicKey.Parameters.P.Bytes())
+		subprime := unpackBase64(t.PublicKey.Parameters.Q.Bytes())
+		base := unpackBase64(t.PublicKey.Parameters.G.Bytes())
+		priv := unpackBase64(t.X.Bytes())
+		pub := unpackBase64(t.PublicKey.Y.Bytes())
+		s = _FORMAT +
+			"Algorithm: " + algorithm + "\n" +
+			"Prime(p): " + prime + "\n" +
+			"Subprime(q): " + subprime + "\n" +
+			"Base(g): " + base + "\n" +
+			"Private_value(x): " + priv + "\n" +
+			"Public_value(y): " + pub + "\n"
+	}
+	return
+}
diff --git a/Godeps/_workspace/src/github.com/miekg/dns/kscan.go b/Godeps/_workspace/src/github.com/miekg/dns/kscan.go
new file mode 100644
index 0000000000..8cc729abce
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/miekg/dns/kscan.go
@@ -0,0 +1,244 @@
+package dns
+
+import (
+	"crypto/dsa"
+	"crypto/ecdsa"
+	"crypto/rsa"
+	"io"
+	"math/big"
+	"strings"
+)
+
+func (k *DNSKEY) NewPrivateKey(s string) (PrivateKey, error) {
+	if s[len(s)-1] != '\n' { // We need a closing newline
+		return k.ReadPrivateKey(strings.NewReader(s+"\n"), "")
+	}
+	return k.ReadPrivateKey(strings.NewReader(s), "")
+}
+
+// ReadPrivateKey reads a private key from the io.Reader q. The string file is
+// only used in error reporting.
+// The public key must be known, because some cryptographic algorithms embed
+// the public key inside the private key.
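+// A typical round trip, as a sketch (error handling elided):
+//
+// k := new(dns.DNSKEY)          // fill in Hdr, Flags, Protocol and Algorithm first
+// priv, _ := k.Generate(1024)   // Generate also sets the public key fields
+// s := k.PrivateKeyString(priv)
+// priv2, _ := k.NewPrivateKey(s)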
+func (k *DNSKEY) ReadPrivateKey(q io.Reader, file string) (PrivateKey, error) { + m, e := parseKey(q, file) + if m == nil { + return nil, e + } + if _, ok := m["private-key-format"]; !ok { + return nil, ErrPrivKey + } + if m["private-key-format"] != "v1.2" && m["private-key-format"] != "v1.3" { + return nil, ErrPrivKey + } + // TODO(mg): check if the pubkey matches the private key + switch m["algorithm"] { + case "3 (DSA)": + p, e := readPrivateKeyDSA(m) + if e != nil { + return nil, e + } + if !k.setPublicKeyInPrivate(p) { + return nil, ErrPrivKey + } + return p, e + case "1 (RSAMD5)": + fallthrough + case "5 (RSASHA1)": + fallthrough + case "7 (RSASHA1NSEC3SHA1)": + fallthrough + case "8 (RSASHA256)": + fallthrough + case "10 (RSASHA512)": + p, e := readPrivateKeyRSA(m) + if e != nil { + return nil, e + } + if !k.setPublicKeyInPrivate(p) { + return nil, ErrPrivKey + } + return p, e + case "12 (ECC-GOST)": + p, e := readPrivateKeyGOST(m) + if e != nil { + return nil, e + } + // setPublicKeyInPrivate(p) + return p, e + case "13 (ECDSAP256SHA256)": + fallthrough + case "14 (ECDSAP384SHA384)": + p, e := readPrivateKeyECDSA(m) + if e != nil { + return nil, e + } + if !k.setPublicKeyInPrivate(p) { + return nil, ErrPrivKey + } + return p, e + } + return nil, ErrPrivKey +} + +// Read a private key (file) string and create a public key. Return the private key. +func readPrivateKeyRSA(m map[string]string) (PrivateKey, error) { + p := new(rsa.PrivateKey) + p.Primes = []*big.Int{nil, nil} + for k, v := range m { + switch k { + case "modulus", "publicexponent", "privateexponent", "prime1", "prime2": + v1, err := packBase64([]byte(v)) + if err != nil { + return nil, err + } + switch k { + case "modulus": + p.PublicKey.N = big.NewInt(0) + p.PublicKey.N.SetBytes(v1) + case "publicexponent": + i := big.NewInt(0) + i.SetBytes(v1) + p.PublicKey.E = int(i.Int64()) // int64 should be large enough + case "privateexponent": + p.D = big.NewInt(0) + p.D.SetBytes(v1) + case "prime1": + p.Primes[0] = big.NewInt(0) + p.Primes[0].SetBytes(v1) + case "prime2": + p.Primes[1] = big.NewInt(0) + p.Primes[1].SetBytes(v1) + } + case "exponent1", "exponent2", "coefficient": + // not used in Go (yet) + case "created", "publish", "activate": + // not used in Go (yet) + } + } + return p, nil +} + +func readPrivateKeyDSA(m map[string]string) (PrivateKey, error) { + p := new(dsa.PrivateKey) + p.X = big.NewInt(0) + for k, v := range m { + switch k { + case "private_value(x)": + v1, err := packBase64([]byte(v)) + if err != nil { + return nil, err + } + p.X.SetBytes(v1) + case "created", "publish", "activate": + /* not used in Go (yet) */ + } + } + return p, nil +} + +func readPrivateKeyECDSA(m map[string]string) (PrivateKey, error) { + p := new(ecdsa.PrivateKey) + p.D = big.NewInt(0) + // TODO: validate that the required flags are present + for k, v := range m { + switch k { + case "privatekey": + v1, err := packBase64([]byte(v)) + if err != nil { + return nil, err + } + p.D.SetBytes(v1) + case "created", "publish", "activate": + /* not used in Go (yet) */ + } + } + return p, nil +} + +func readPrivateKeyGOST(m map[string]string) (PrivateKey, error) { + // TODO(miek) + return nil, nil +} + +// parseKey reads a private key from r. It returns a map[string]string, +// with the key-value pairs, or an error when the file is not correct. 
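+// The expected input is BIND's private key file format, e.g.:
+//
+// Private-key-format: v1.3
+// Algorithm: 8 (RSASHA256)
+// Modulus: ...
+//
+// Key names are lowercased before they are stored in the map.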
+func parseKey(r io.Reader, file string) (map[string]string, error) {
+	s := scanInit(r)
+	m := make(map[string]string)
+	c := make(chan lex)
+	k := ""
+	// Start the lexer
+	go klexer(s, c)
+	for l := range c {
+		// It should alternate
+		switch l.value {
+		case _KEY:
+			k = l.token
+		case _VALUE:
+			if k == "" {
+				return nil, &ParseError{file, "no private key seen", l}
+			}
+			//println("Setting", strings.ToLower(k), "to", l.token, "b")
+			m[strings.ToLower(k)] = l.token
+			k = ""
+		}
+	}
+	return m, nil
+}
+
+// klexer scans the sourcefile and returns tokens on the channel c.
+func klexer(s *scan, c chan lex) {
+	var l lex
+	str := "" // Hold the current read text
+	commt := false
+	key := true
+	x, err := s.tokenText()
+	defer close(c)
+	for err == nil {
+		l.column = s.position.Column
+		l.line = s.position.Line
+		switch x {
+		case ':':
+			if commt {
+				break
+			}
+			l.token = str
+			if key {
+				l.value = _KEY
+				c <- l
+				// Next token is a space, eat it
+				s.tokenText()
+				key = false
+				str = ""
+			} else {
+				l.value = _VALUE
+			}
+		case ';':
+			commt = true
+		case '\n':
+			if commt {
+				// Reset a comment
+				commt = false
+			}
+			l.value = _VALUE
+			l.token = str
+			c <- l
+			str = ""
+			commt = false
+			key = true
+		default:
+			if commt {
+				break
+			}
+			str += string(x)
+		}
+		x, err = s.tokenText()
+	}
+	if len(str) > 0 {
+		// Send remainder
+		l.token = str
+		l.value = _VALUE
+		c <- l
+	}
+}
diff --git a/Godeps/_workspace/src/github.com/miekg/dns/labels.go b/Godeps/_workspace/src/github.com/miekg/dns/labels.go
new file mode 100644
index 0000000000..758e5783de
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/miekg/dns/labels.go
@@ -0,0 +1,162 @@
+package dns
+
+// Holds a bunch of helper functions for dealing with labels.
+
+// SplitDomainName splits a name string into its labels.
+// www.miek.nl. returns []string{"www", "miek", "nl"}
+// The root label (.) returns nil. Note that using
+// strings.Split(s, ".") will work in most cases, but does not handle
+// escaped dots (\.) for instance.
+func SplitDomainName(s string) (labels []string) {
+	if len(s) == 0 {
+		return nil
+	}
+	fqdnEnd := 0 // offset of the final '.' or the length of the name
+	idx := Split(s)
+	begin := 0
+	if s[len(s)-1] == '.' {
+		fqdnEnd = len(s) - 1
+	} else {
+		fqdnEnd = len(s)
+	}
+
+	switch len(idx) {
+	case 0:
+		return nil
+	case 1:
+		// no-op
+	default:
+		end := 0
+		for i := 1; i < len(idx); i++ {
+			end = idx[i]
+			labels = append(labels, s[begin:end-1])
+			begin = end
+		}
+	}
+
+	labels = append(labels, s[begin:fqdnEnd])
+	return labels
+}
+
+// CompareDomainName compares the names s1 and s2 and
+// returns how many labels they have in common starting from the *right*.
+// The comparison stops at the first inequality. The names are not downcased
+// before the comparison.
+//
+// www.miek.nl. and miek.nl. have two labels in common: miek and nl
+// www.miek.nl. and www.bla.nl. have one label in common: nl
+func CompareDomainName(s1, s2 string) (n int) {
+	s1 = Fqdn(s1)
+	s2 = Fqdn(s2)
+	l1 := Split(s1)
+	l2 := Split(s2)
+
+	// the first check: root label
+	if l1 == nil || l2 == nil {
+		return
+	}
+
+	j1 := len(l1) - 1 // end
+	i1 := len(l1) - 2 // start
+	j2 := len(l2) - 1
+	i2 := len(l2) - 2
+	// the second check can be done here: last/only label
+	// before we fall through into the for-loop below
+	if s1[l1[j1]:] == s2[l2[j2]:] {
+		n++
+	} else {
+		return
+	}
+	for {
+		if i1 < 0 || i2 < 0 {
+			break
+		}
+		if s1[l1[i1]:l1[j1]] == s2[l2[i2]:l2[j2]] {
+			n++
+		} else {
+			break
+		}
+		j1--
+		i1--
+		j2--
+		i2--
+	}
+	return
+}
+
+// CountLabel counts the number of labels in the string s.
+func CountLabel(s string) (labels int) {
+	if s == "." {
+		return
+	}
+	off := 0
+	end := false
+	for {
+		off, end = NextLabel(s, off)
+		labels++
+		if end {
+			return
+		}
+	}
+	panic("dns: not reached")
+}
+
+// Split splits a name s into its label indexes.
+// www.miek.nl. returns []int{0, 4, 9}, www.miek.nl also returns []int{0, 4, 9}.
+// The root name (.) returns nil. Also see dns.SplitDomainName.
+func Split(s string) []int {
+	if s == "." {
+		return nil
+	}
+	idx := make([]int, 1, 3)
+	off := 0
+	end := false
+
+	for {
+		off, end = NextLabel(s, off)
+		if end {
+			return idx
+		}
+		idx = append(idx, off)
+	}
+	panic("dns: not reached")
+}
+
+// NextLabel returns the index of the start of the next label in the
+// string s starting at offset.
+// The bool end is true when the end of the string has been reached.
+func NextLabel(s string, offset int) (i int, end bool) {
+	quote := false
+	for i = offset; i < len(s)-1; i++ {
+		switch s[i] {
+		case '\\':
+			quote = !quote
+		default:
+			quote = false
+		case '.':
+			if quote {
+				quote = !quote
+				continue
+			}
+			return i + 1, false
+		}
+	}
+	return i + 1, true
+}
+
+// PrevLabel returns the index of the label when starting from the right and
+// jumping n labels to the left.
+// The bool start is true when the start of the string has been overshot.
+func PrevLabel(s string, n int) (i int, start bool) {
+	if n == 0 {
+		return len(s), false
+	}
+	lab := Split(s)
+	if lab == nil {
+		return 0, true
+	}
+	if n > len(lab) {
+		return 0, true
+	}
+	return lab[len(lab)-n], false
+}
diff --git a/Godeps/_workspace/src/github.com/miekg/dns/labels_test.go b/Godeps/_workspace/src/github.com/miekg/dns/labels_test.go
new file mode 100644
index 0000000000..1d8da155fd
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/miekg/dns/labels_test.go
@@ -0,0 +1,214 @@
+package dns
+
+import (
+	"testing"
+)
+
+func TestCompareDomainName(t *testing.T) {
+	s1 := "www.miek.nl."
+	s2 := "miek.nl."
+	s3 := "www.bla.nl."
+	s4 := "nl.www.bla."
+ s5 := "nl" + s6 := "miek.nl" + + if CompareDomainName(s1, s2) != 2 { + t.Logf("%s with %s should be %d", s1, s2, 2) + t.Fail() + } + if CompareDomainName(s1, s3) != 1 { + t.Logf("%s with %s should be %d", s1, s3, 1) + t.Fail() + } + if CompareDomainName(s3, s4) != 0 { + t.Logf("%s with %s should be %d", s3, s4, 0) + t.Fail() + } + // Non qualified tests + if CompareDomainName(s1, s5) != 1 { + t.Logf("%s with %s should be %d", s1, s5, 1) + t.Fail() + } + if CompareDomainName(s1, s6) != 2 { + t.Logf("%s with %s should be %d", s1, s5, 2) + t.Fail() + } + + if CompareDomainName(s1, ".") != 0 { + t.Logf("%s with %s should be %d", s1, s5, 0) + t.Fail() + } + if CompareDomainName(".", ".") != 0 { + t.Logf("%s with %s should be %d", ".", ".", 0) + t.Fail() + } +} + +func TestSplit(t *testing.T) { + splitter := map[string]int{ + "www.miek.nl.": 3, + "www.miek.nl": 3, + "www..miek.nl": 4, + `www\.miek.nl.`: 2, + `www\\.miek.nl.`: 3, + ".": 0, + "nl.": 1, + "nl": 1, + "com.": 1, + ".com.": 2, + } + for s, i := range splitter { + if x := len(Split(s)); x != i { + t.Logf("labels should be %d, got %d: %s %v\n", i, x, s, Split(s)) + t.Fail() + } else { + t.Logf("%s %v\n", s, Split(s)) + } + } +} + +func TestSplit2(t *testing.T) { + splitter := map[string][]int{ + "www.miek.nl.": []int{0, 4, 9}, + "www.miek.nl": []int{0, 4, 9}, + "nl": []int{0}, + } + for s, i := range splitter { + x := Split(s) + switch len(i) { + case 1: + if x[0] != i[0] { + t.Logf("labels should be %v, got %v: %s\n", i, x, s) + t.Fail() + } + default: + if x[0] != i[0] || x[1] != i[1] || x[2] != i[2] { + t.Logf("labels should be %v, got %v: %s\n", i, x, s) + t.Fail() + } + } + } +} + +func TestPrevLabel(t *testing.T) { + type prev struct { + string + int + } + prever := map[prev]int{ + prev{"www.miek.nl.", 0}: 12, + prev{"www.miek.nl.", 1}: 9, + prev{"www.miek.nl.", 2}: 4, + + prev{"www.miek.nl", 0}: 11, + prev{"www.miek.nl", 1}: 9, + prev{"www.miek.nl", 2}: 4, + + prev{"www.miek.nl.", 5}: 0, + prev{"www.miek.nl", 5}: 0, + + prev{"www.miek.nl.", 3}: 0, + prev{"www.miek.nl", 3}: 0, + } + for s, i := range prever { + x, ok := PrevLabel(s.string, s.int) + if i != x { + t.Logf("label should be %d, got %d, %t: preving %d, %s\n", i, x, ok, s.int, s.string) + t.Fail() + } + } +} + +func TestCountLabel(t *testing.T) { + splitter := map[string]int{ + "www.miek.nl.": 3, + "www.miek.nl": 3, + "nl": 1, + ".": 0, + } + for s, i := range splitter { + x := CountLabel(s) + if x != i { + t.Logf("CountLabel should have %d, got %d\n", i, x) + t.Fail() + } + } +} + +func TestSplitDomainName(t *testing.T) { + labels := map[string][]string{ + "miek.nl": []string{"miek", "nl"}, + ".": nil, + "www.miek.nl.": []string{"www", "miek", "nl"}, + "www.miek.nl": []string{"www", "miek", "nl"}, + "www..miek.nl": []string{"www", "", "miek", "nl"}, + `www\.miek.nl`: []string{`www\.miek`, "nl"}, + `www\\.miek.nl`: []string{`www\\`, "miek", "nl"}, + } +domainLoop: + for domain, splits := range labels { + parts := SplitDomainName(domain) + if len(parts) != len(splits) { + t.Logf("SplitDomainName returned %v for %s, expected %v", parts, domain, splits) + t.Fail() + continue domainLoop + } + for i := range parts { + if parts[i] != splits[i] { + t.Logf("SplitDomainName returned %v for %s, expected %v", parts, domain, splits) + t.Fail() + continue domainLoop + } + } + } +} + +func TestIsDomainName(t *testing.T) { + type ret struct { + ok bool + lab int + } + names := map[string]*ret{ + "..": &ret{false, 1}, + "@.": &ret{true, 1}, + "www.example.com": &ret{true, 3}, + 
"www.e%ample.com": &ret{true, 3}, + "www.example.com.": &ret{true, 3}, + "mi\\k.nl.": &ret{true, 2}, + "mi\\k.nl": &ret{true, 2}, + } + for d, ok := range names { + l, k := IsDomainName(d) + if ok.ok != k || ok.lab != l { + t.Logf(" got %v %d for %s ", k, l, d) + t.Logf("have %v %d for %s ", ok.ok, ok.lab, d) + t.Fail() + } + } +} + +func BenchmarkSplitLabels(b *testing.B) { + for i := 0; i < b.N; i++ { + Split("www.example.com") + } +} + +func BenchmarkLenLabels(b *testing.B) { + for i := 0; i < b.N; i++ { + CountLabel("www.example.com") + } +} + +func BenchmarkCompareLabels(b *testing.B) { + for i := 0; i < b.N; i++ { + CompareDomainName("www.example.com", "aa.example.com") + } +} + +func BenchmarkIsSubDomain(b *testing.B) { + for i := 0; i < b.N; i++ { + IsSubDomain("www.example.com", "aa.example.com") + IsSubDomain("example.com", "aa.example.com") + IsSubDomain("miek.nl", "aa.example.com") + } +} diff --git a/Godeps/_workspace/src/github.com/miekg/dns/msg.go b/Godeps/_workspace/src/github.com/miekg/dns/msg.go new file mode 100644 index 0000000000..2b12cd4801 --- /dev/null +++ b/Godeps/_workspace/src/github.com/miekg/dns/msg.go @@ -0,0 +1,1881 @@ +// DNS packet assembly, see RFC 1035. Converting from - Unpack() - +// and to - Pack() - wire format. +// All the packers and unpackers take a (msg []byte, off int) +// and return (off1 int, ok bool). If they return ok==false, they +// also return off1==len(msg), so that the next unpacker will +// also fail. This lets us avoid checks of ok until the end of a +// packing sequence. + +package dns + +import ( + "encoding/base32" + "encoding/base64" + "encoding/hex" + "fmt" + "math/rand" + "net" + "reflect" + "strconv" + "time" +) + +const maxCompressionOffset = 2 << 13 // We have 14 bits for the compression pointer + +var ( + ErrAlg error = &Error{err: "bad algorithm"} + ErrAuth error = &Error{err: "bad authentication"} + ErrBuf error = &Error{err: "buffer size too small"} + ErrConnEmpty error = &Error{err: "conn has no connection"} + ErrConn error = &Error{err: "conn holds both UDP and TCP connection"} + ErrExtendedRcode error = &Error{err: "bad extended rcode"} + ErrFqdn error = &Error{err: "domain must be fully qualified"} + ErrId error = &Error{err: "id mismatch"} + ErrKeyAlg error = &Error{err: "bad key algorithm"} + ErrKey error = &Error{err: "bad key"} + ErrKeySize error = &Error{err: "bad key size"} + ErrNoSig error = &Error{err: "no signature found"} + ErrPrivKey error = &Error{err: "bad private key"} + ErrRcode error = &Error{err: "bad rcode"} + ErrRdata error = &Error{err: "bad rdata"} + ErrRRset error = &Error{err: "bad rrset"} + ErrSecret error = &Error{err: "no secrets defined"} + ErrServ error = &Error{err: "no servers could be reached"} + ErrShortRead error = &Error{err: "short read"} + ErrSig error = &Error{err: "bad signature"} + ErrSigGen error = &Error{err: "bad signature generation"} + ErrSoa error = &Error{err: "no SOA"} + ErrTime error = &Error{err: "bad time"} +) + +// Id, by default, returns a 16 bits random number to be used as a +// message id. The random provided should be good enough. This being a +// variable the function can be reassigned to a custom function. +// For instance, to make it return a static value: +// +// dns.Id = func() uint16 { return 3 } +var Id func() uint16 = id + +// A manually-unpacked version of (id, bits). +// This is in its own struct for easy printing. 
+type MsgHdr struct { + Id uint16 + Response bool + Opcode int + Authoritative bool + Truncated bool + RecursionDesired bool + RecursionAvailable bool + Zero bool + AuthenticatedData bool + CheckingDisabled bool + Rcode int +} + +// The layout of a DNS message. +type Msg struct { + MsgHdr + Compress bool `json:"-"` // If true, the message will be compressed when converted to wire format. This not part of the official DNS packet format. + Question []Question // Holds the RR(s) of the question section. + Answer []RR // Holds the RR(s) of the answer section. + Ns []RR // Holds the RR(s) of the authority section. + Extra []RR // Holds the RR(s) of the additional section. +} + +// Map of strings for each RR wire type. +var TypeToString = map[uint16]string{ + TypeA: "A", + TypeAAAA: "AAAA", + TypeAFSDB: "AFSDB", + TypeANY: "ANY", // Meta RR + TypeATMA: "ATMA", + TypeAXFR: "AXFR", // Meta RR + TypeCAA: "CAA", + TypeCDS: "CDS", + TypeCERT: "CERT", + TypeCNAME: "CNAME", + TypeDHCID: "DHCID", + TypeDLV: "DLV", + TypeDNAME: "DNAME", + TypeDNSKEY: "DNSKEY", + TypeDS: "DS", + TypeEID: "EID", + TypeEUI48: "EUI48", + TypeEUI64: "EUI64", + TypeGID: "GID", + TypeGPOS: "GPOS", + TypeHINFO: "HINFO", + TypeHIP: "HIP", + TypeIPSECKEY: "IPSECKEY", + TypeISDN: "ISDN", + TypeIXFR: "IXFR", // Meta RR + TypeKX: "KX", + TypeL32: "L32", + TypeL64: "L64", + TypeLOC: "LOC", + TypeLP: "LP", + TypeMB: "MB", + TypeMD: "MD", + TypeMF: "MF", + TypeMG: "MG", + TypeMINFO: "MINFO", + TypeMR: "MR", + TypeMX: "MX", + TypeNAPTR: "NAPTR", + TypeNID: "NID", + TypeNINFO: "NINFO", + TypeNIMLOC: "NIMLOC", + TypeNS: "NS", + TypeNSAP: "NSAP", + TypeNSAPPTR: "NSAP-PTR", + TypeNSEC3: "NSEC3", + TypeNSEC3PARAM: "NSEC3PARAM", + TypeNSEC: "NSEC", + TypeNULL: "NULL", + TypeOPT: "OPT", + TypeOPENPGPKEY: "OPENPGPKEY", + TypePTR: "PTR", + TypeRKEY: "RKEY", + TypeRP: "RP", + TypeRRSIG: "RRSIG", + TypeRT: "RT", + TypeSOA: "SOA", + TypeSPF: "SPF", + TypeSRV: "SRV", + TypeSSHFP: "SSHFP", + TypeTA: "TA", + TypeTALINK: "TALINK", + TypeTKEY: "TKEY", // Meta RR + TypeTLSA: "TLSA", + TypeTSIG: "TSIG", // Meta RR + TypeTXT: "TXT", + TypePX: "PX", + TypeUID: "UID", + TypeUINFO: "UINFO", + TypeUNSPEC: "UNSPEC", + TypeURI: "URI", + TypeWKS: "WKS", + TypeX25: "X25", +} + +// Reverse, needed for string parsing. +var StringToType = reverseInt16(TypeToString) +var StringToClass = reverseInt16(ClassToString) + +// Map of opcodes strings. +var StringToOpcode = reverseInt(OpcodeToString) + +// Map of rcodes strings. +var StringToRcode = reverseInt(RcodeToString) + +// Map of strings for each CLASS wire type. +var ClassToString = map[uint16]string{ + ClassINET: "IN", + ClassCSNET: "CS", + ClassCHAOS: "CH", + ClassHESIOD: "HS", + ClassNONE: "NONE", + ClassANY: "ANY", +} + +// Map of strings for opcodes. +var OpcodeToString = map[int]string{ + OpcodeQuery: "QUERY", + OpcodeIQuery: "IQUERY", + OpcodeStatus: "STATUS", + OpcodeNotify: "NOTIFY", + OpcodeUpdate: "UPDATE", +} + +// Map of strings for rcodes. 
+var RcodeToString = map[int]string{ + RcodeSuccess: "NOERROR", + RcodeFormatError: "FORMERR", + RcodeServerFailure: "SERVFAIL", + RcodeNameError: "NXDOMAIN", + RcodeNotImplemented: "NOTIMPL", + RcodeRefused: "REFUSED", + RcodeYXDomain: "YXDOMAIN", // From RFC 2136 + RcodeYXRrset: "YXRRSET", + RcodeNXRrset: "NXRRSET", + RcodeNotAuth: "NOTAUTH", + RcodeNotZone: "NOTZONE", + RcodeBadSig: "BADSIG", // Also known as RcodeBadVers, see RFC 6891 + // RcodeBadVers: "BADVERS", + RcodeBadKey: "BADKEY", + RcodeBadTime: "BADTIME", + RcodeBadMode: "BADMODE", + RcodeBadName: "BADNAME", + RcodeBadAlg: "BADALG", + RcodeBadTrunc: "BADTRUNC", +} + +// Rather than write the usual handful of routines to pack and +// unpack every message that can appear on the wire, we use +// reflection to write a generic pack/unpack for structs and then +// use it. Thus, if in the future we need to define new message +// structs, no new pack/unpack/printing code needs to be written. + +// Domain names are a sequence of counted strings +// split at the dots. They end with a zero-length string. + +// PackDomainName packs a domain name s into msg[off:]. +// If compression is wanted compress must be true and the compression +// map needs to hold a mapping between domain names and offsets +// pointing into msg[]. +func PackDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) { + off1, _, err = packDomainName(s, msg, off, compression, compress) + return +} + +func packDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, labels int, err error) { + // special case if msg == nil + lenmsg := 256 + if msg != nil { + lenmsg = len(msg) + } + ls := len(s) + if ls == 0 { // Ok, for instance when dealing with update RR without any rdata. + return off, 0, nil + } + // If not fully qualified, error out, but only if msg == nil #ugly + switch { + case msg == nil: + if s[ls-1] != '.' { + s += "." + ls++ + } + case msg != nil: + if s[ls-1] != '.' { + return lenmsg, 0, ErrFqdn + } + } + // Each dot ends a segment of the name. + // We trade each dot byte for a length byte. + // Except for escaped dots (\.), which are normal dots. + // There is also a trailing zero. + + // Compression + nameoffset := -1 + pointer := -1 + // Emit sequence of counted strings, chopping at dots. + begin := 0 + bs := []byte(s) + ro_bs, bs_fresh, escaped_dot := s, true, false + for i := 0; i < ls; i++ { + if bs[i] == '\\' { + for j := i; j < ls-1; j++ { + bs[j] = bs[j+1] + } + ls-- + if off+1 > lenmsg { + return lenmsg, labels, ErrBuf + } + // check for \DDD + if i+2 < ls && isDigit(bs[i]) && isDigit(bs[i+1]) && isDigit(bs[i+2]) { + bs[i] = dddToByte(bs[i:]) + for j := i + 1; j < ls-2; j++ { + bs[j] = bs[j+2] + } + ls -= 2 + } else if bs[i] == 't' { + bs[i] = '\t' + } else if bs[i] == 'r' { + bs[i] = '\r' + } else if bs[i] == 'n' { + bs[i] = '\n' + } + escaped_dot = bs[i] == '.' + bs_fresh = false + continue + } + + if bs[i] == '.' { + if i > 0 && bs[i-1] == '.' 
&& !escaped_dot { + // two dots back to back is not legal + return lenmsg, labels, ErrRdata + } + if i-begin >= 1<<6 { // top two bits of length must be clear + return lenmsg, labels, ErrRdata + } + // off can already (we're in a loop) be bigger than len(msg) + // this happens when a name isn't fully qualified + if off+1 > lenmsg { + return lenmsg, labels, ErrBuf + } + if msg != nil { + msg[off] = byte(i - begin) + } + offset := off + off++ + for j := begin; j < i; j++ { + if off+1 > lenmsg { + return lenmsg, labels, ErrBuf + } + if msg != nil { + msg[off] = bs[j] + } + off++ + } + if compress && !bs_fresh { + ro_bs = string(bs) + bs_fresh = true + } + // Dont try to compress '.' + if compress && ro_bs[begin:] != "." { + if p, ok := compression[ro_bs[begin:]]; !ok { + // Only offsets smaller than this can be used. + if offset < maxCompressionOffset { + compression[ro_bs[begin:]] = offset + } + } else { + // The first hit is the longest matching dname + // keep the pointer offset we get back and store + // the offset of the current name, because that's + // where we need to insert the pointer later + + // If compress is true, we're allowed to compress this dname + if pointer == -1 && compress { + pointer = p // Where to point to + nameoffset = offset // Where to point from + break + } + } + } + labels++ + begin = i + 1 + } + escaped_dot = false + } + // Root label is special + if len(bs) == 1 && bs[0] == '.' { + return off, labels, nil + } + // If we did compression and we find something add the pointer here + if pointer != -1 { + // We have two bytes (14 bits) to put the pointer in + // if msg == nil, we will never do compression + msg[nameoffset], msg[nameoffset+1] = packUint16(uint16(pointer ^ 0xC000)) + off = nameoffset + 1 + goto End + } + if msg != nil { + msg[off] = 0 + } +End: + off++ + return off, labels, nil +} + +// Unpack a domain name. +// In addition to the simple sequences of counted strings above, +// domain names are allowed to refer to strings elsewhere in the +// packet, to avoid repeating common suffixes when returning +// many entries in a single domain. The pointers are marked +// by a length byte with the top two bits set. Ignoring those +// two bits, that byte and the next give a 14 bit offset from msg[0] +// where we should pick up the trail. +// Note that if we jump elsewhere in the packet, +// we return off1 == the offset after the first pointer we found, +// which is where the next record will start. +// In theory, the pointers are only allowed to jump backward. +// We let them jump anywhere and stop jumping after a while. + +// UnpackDomainName unpacks a domain name into a string. +func UnpackDomainName(msg []byte, off int) (string, int, error) { + s := make([]byte, 0, 64) + off1 := 0 + lenmsg := len(msg) + ptr := 0 // number of pointers followed +Loop: + for { + if off >= lenmsg { + return "", lenmsg, ErrBuf + } + c := int(msg[off]) + off++ + switch c & 0xC0 { + case 0x00: + if c == 0x00 { + // end of name + if len(s) == 0 { + return ".", off, nil + } + break Loop + } + // literal string + if off+c > lenmsg { + return "", lenmsg, ErrBuf + } + for j := off; j < off+c; j++ { + switch b := msg[j]; b { + case '.', '(', ')', ';', ' ', '@': + fallthrough + case '"', '\\': + s = append(s, '\\', b) + case '\t': + s = append(s, '\\', 't') + case '\r': + s = append(s, '\\', 'r') + default: + if b < 32 || b >= 127 { // unprintable use \DDD + s = append(s, fmt.Sprintf("\\%03d", b)...) 
+ } else { + s = append(s, b) + } + } + } + s = append(s, '.') + off += c + case 0xC0: + // pointer to somewhere else in msg. + // remember location after first ptr, + // since that's how many bytes we consumed. + // also, don't follow too many pointers -- + // maybe there's a loop. + if off >= lenmsg { + return "", lenmsg, ErrBuf + } + c1 := msg[off] + off++ + if ptr == 0 { + off1 = off + } + if ptr++; ptr > 10 { + return "", lenmsg, &Error{err: "too many compression pointers"} + } + off = (c^0xC0)<<8 | int(c1) + default: + // 0x80 and 0x40 are reserved + return "", lenmsg, ErrRdata + } + } + if ptr == 0 { + off1 = off + } + return string(s), off1, nil +} + +func packTxt(txt []string, msg []byte, offset int, tmp []byte) (int, error) { + var err error + if len(txt) == 0 { + if offset >= len(msg) { + return offset, ErrBuf + } + msg[offset] = 0 + return offset, nil + } + for i := range txt { + if len(txt[i]) > len(tmp) { + return offset, ErrBuf + } + offset, err = packTxtString(txt[i], msg, offset, tmp) + if err != nil { + return offset, err + } + } + return offset, err +} + +func packTxtString(s string, msg []byte, offset int, tmp []byte) (int, error) { + lenByteOffset := offset + if offset >= len(msg) { + return offset, ErrBuf + } + offset++ + bs := tmp[:len(s)] + copy(bs, s) + for i := 0; i < len(bs); i++ { + if len(msg) <= offset { + return offset, ErrBuf + } + if bs[i] == '\\' { + i++ + if i == len(bs) { + break + } + // check for \DDD + if i+2 < len(bs) && isDigit(bs[i]) && isDigit(bs[i+1]) && isDigit(bs[i+2]) { + msg[offset] = dddToByte(bs[i:]) + i += 2 + } else if bs[i] == 't' { + msg[offset] = '\t' + } else if bs[i] == 'r' { + msg[offset] = '\r' + } else if bs[i] == 'n' { + msg[offset] = '\n' + } else { + msg[offset] = bs[i] + } + } else { + msg[offset] = bs[i] + } + offset++ + } + l := offset - lenByteOffset - 1 + if l > 255 { + return offset, &Error{err: "string exceeded 255 bytes in txt"} + } + msg[lenByteOffset] = byte(l) + return offset, nil +} + +func unpackTxt(msg []byte, offset, rdend int) ([]string, int, error) { + var err error + var ss []string + var s string + for offset < rdend && err == nil { + s, offset, err = unpackTxtString(msg, offset) + if err == nil { + ss = append(ss, s) + } + } + return ss, offset, err +} + +func unpackTxtString(msg []byte, offset int) (string, int, error) { + if offset+1 > len(msg) { + return "", offset, &Error{err: "overflow unpacking txt"} + } + l := int(msg[offset]) + if offset+l+1 > len(msg) { + return "", offset, &Error{err: "overflow unpacking txt"} + } + s := make([]byte, 0, l) + for _, b := range msg[offset+1 : offset+1+l] { + switch b { + case '"', '\\': + s = append(s, '\\', b) + case '\t': + s = append(s, `\t`...) + case '\r': + s = append(s, `\r`...) + case '\n': + s = append(s, `\n`...) + default: + if b < 32 || b > 127 { // unprintable + s = append(s, fmt.Sprintf("\\%03d", b)...) + } else { + s = append(s, b) + } + } + } + offset += 1 + l + return string(s), offset, nil +} + +// Pack a reflect.StructValue into msg. Struct members can only be uint8, uint16, uint32, string, +// slices and other (often anonymous) structs. 
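+// The `dns` struct tags select the wire encoding for a field; as a sketch
+// modeled on this package's MX type:
+//
+// type MX struct {
+//	Hdr        RR_Header
+//	Preference uint16
+//	Mx         string `dns:"cdomain-name"`
+// }
+//
+// where the tag makes Mx pack as a (compressible) domain name instead of a
+// counted string.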
+func packStructValue(val reflect.Value, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) { + var txtTmp []byte + lenmsg := len(msg) + numfield := val.NumField() + for i := 0; i < numfield; i++ { + typefield := val.Type().Field(i) + if typefield.Tag == `dns:"-"` { + continue + } + switch fv := val.Field(i); fv.Kind() { + default: + return lenmsg, &Error{err: "bad kind packing"} + case reflect.Interface: + // PrivateRR is the only RR implementation that has interface field. + // therefore it's expected that this interface would be PrivateRdata + switch data := fv.Interface().(type) { + case PrivateRdata: + n, err := data.Pack(msg[off:]) + if err != nil { + return lenmsg, err + } + off += n + default: + return lenmsg, &Error{err: "bad kind interface packing"} + } + case reflect.Slice: + switch typefield.Tag { + default: + return lenmsg, &Error{"bad tag packing slice: " + typefield.Tag.Get("dns")} + case `dns:"domain-name"`: + for j := 0; j < val.Field(i).Len(); j++ { + element := val.Field(i).Index(j).String() + off, err = PackDomainName(element, msg, off, compression, false && compress) + if err != nil { + return lenmsg, err + } + } + case `dns:"txt"`: + if txtTmp == nil { + txtTmp = make([]byte, 256*4+1) + } + off, err = packTxt(fv.Interface().([]string), msg, off, txtTmp) + if err != nil { + return lenmsg, err + } + case `dns:"opt"`: // edns + for j := 0; j < val.Field(i).Len(); j++ { + element := val.Field(i).Index(j).Interface() + b, e := element.(EDNS0).pack() + if e != nil { + return lenmsg, &Error{err: "overflow packing opt"} + } + // Option code + msg[off], msg[off+1] = packUint16(element.(EDNS0).Option()) + // Length + msg[off+2], msg[off+3] = packUint16(uint16(len(b))) + off += 4 + if off+len(b) > lenmsg { + copy(msg[off:], b) + off = lenmsg + continue + } + // Actual data + copy(msg[off:off+len(b)], b) + off += len(b) + } + case `dns:"a"`: + // It must be a slice of 4, even if it is 16, we encode + // only the first 4 + if off+net.IPv4len > lenmsg { + return lenmsg, &Error{err: "overflow packing a"} + } + switch fv.Len() { + case net.IPv6len: + msg[off] = byte(fv.Index(12).Uint()) + msg[off+1] = byte(fv.Index(13).Uint()) + msg[off+2] = byte(fv.Index(14).Uint()) + msg[off+3] = byte(fv.Index(15).Uint()) + off += net.IPv4len + case net.IPv4len: + msg[off] = byte(fv.Index(0).Uint()) + msg[off+1] = byte(fv.Index(1).Uint()) + msg[off+2] = byte(fv.Index(2).Uint()) + msg[off+3] = byte(fv.Index(3).Uint()) + off += net.IPv4len + case 0: + // Allowed, for dynamic updates + default: + return lenmsg, &Error{err: "overflow packing a"} + } + case `dns:"aaaa"`: + if fv.Len() == 0 { + break + } + if fv.Len() > net.IPv6len || off+fv.Len() > lenmsg { + return lenmsg, &Error{err: "overflow packing aaaa"} + } + for j := 0; j < net.IPv6len; j++ { + msg[off] = byte(fv.Index(j).Uint()) + off++ + } + case `dns:"wks"`: + if off == lenmsg { + break // dyn. 
updates + } + if val.Field(i).Len() == 0 { + break + } + var bitmapbyte uint16 + for j := 0; j < val.Field(i).Len(); j++ { + serv := uint16((fv.Index(j).Uint())) + bitmapbyte = uint16(serv / 8) + if int(bitmapbyte) > lenmsg { + return lenmsg, &Error{err: "overflow packing wks"} + } + bit := uint16(serv) - bitmapbyte*8 + msg[bitmapbyte] = byte(1 << (7 - bit)) + } + off += int(bitmapbyte) + case `dns:"nsec"`: // NSEC/NSEC3 + // This is the uint16 type bitmap + if val.Field(i).Len() == 0 { + // Do absolutely nothing + break + } + + lastwindow := uint16(0) + length := uint16(0) + if off+2 > lenmsg { + return lenmsg, &Error{err: "overflow packing nsecx"} + } + for j := 0; j < val.Field(i).Len(); j++ { + t := uint16((fv.Index(j).Uint())) + window := uint16(t / 256) + if lastwindow != window { + // New window, jump to the new offset + off += int(length) + 3 + if off > lenmsg { + return lenmsg, &Error{err: "overflow packing nsecx bitmap"} + } + } + length = (t - window*256) / 8 + bit := t - (window * 256) - (length * 8) + if off+2+int(length) > lenmsg { + return lenmsg, &Error{err: "overflow packing nsecx bitmap"} + } + + // Setting the window # + msg[off] = byte(window) + // Setting the octets length + msg[off+1] = byte(length + 1) + // Setting the bit value for the type in the right octet + msg[off+2+int(length)] |= byte(1 << (7 - bit)) + lastwindow = window + } + off += 2 + int(length) + off++ + if off > lenmsg { + return lenmsg, &Error{err: "overflow packing nsecx bitmap"} + } + } + case reflect.Struct: + off, err = packStructValue(fv, msg, off, compression, compress) + if err != nil { + return lenmsg, err + } + case reflect.Uint8: + if off+1 > lenmsg { + return lenmsg, &Error{err: "overflow packing uint8"} + } + msg[off] = byte(fv.Uint()) + off++ + case reflect.Uint16: + if off+2 > lenmsg { + return lenmsg, &Error{err: "overflow packing uint16"} + } + i := fv.Uint() + msg[off] = byte(i >> 8) + msg[off+1] = byte(i) + off += 2 + case reflect.Uint32: + if off+4 > lenmsg { + return lenmsg, &Error{err: "overflow packing uint32"} + } + i := fv.Uint() + msg[off] = byte(i >> 24) + msg[off+1] = byte(i >> 16) + msg[off+2] = byte(i >> 8) + msg[off+3] = byte(i) + off += 4 + case reflect.Uint64: + switch typefield.Tag { + default: + if off+8 > lenmsg { + return lenmsg, &Error{err: "overflow packing uint64"} + } + i := fv.Uint() + msg[off] = byte(i >> 56) + msg[off+1] = byte(i >> 48) + msg[off+2] = byte(i >> 40) + msg[off+3] = byte(i >> 32) + msg[off+4] = byte(i >> 24) + msg[off+5] = byte(i >> 16) + msg[off+6] = byte(i >> 8) + msg[off+7] = byte(i) + off += 8 + case `dns:"uint48"`: + // Used in TSIG, where it stops at 48 bits, so we discard the upper 16 + if off+6 > lenmsg { + return lenmsg, &Error{err: "overflow packing uint64 as uint48"} + } + i := fv.Uint() + msg[off] = byte(i >> 40) + msg[off+1] = byte(i >> 32) + msg[off+2] = byte(i >> 24) + msg[off+3] = byte(i >> 16) + msg[off+4] = byte(i >> 8) + msg[off+5] = byte(i) + off += 6 + } + case reflect.String: + // There are multiple string encodings. + // The tag distinguishes ordinary strings from domain names. 
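+			// Besides plain strings and domain names, the tags below also
+			// select base64, base32, hex, raw "size" and TXT counted-string
+			// encodings.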
+ s := fv.String() + switch typefield.Tag { + default: + return lenmsg, &Error{"bad tag packing string: " + typefield.Tag.Get("dns")} + case `dns:"base64"`: + b64, e := packBase64([]byte(s)) + if e != nil { + return lenmsg, e + } + copy(msg[off:off+len(b64)], b64) + off += len(b64) + case `dns:"domain-name"`: + if off, err = PackDomainName(s, msg, off, compression, false && compress); err != nil { + return lenmsg, err + } + case `dns:"cdomain-name"`: + if off, err = PackDomainName(s, msg, off, compression, true && compress); err != nil { + return lenmsg, err + } + case `dns:"size-base32"`: + // This is purely for NSEC3 atm, the previous byte must + // holds the length of the encoded string. As NSEC3 + // is only defined to SHA1, the hashlength is 20 (160 bits) + msg[off-1] = 20 + fallthrough + case `dns:"base32"`: + b32, e := packBase32([]byte(s)) + if e != nil { + return lenmsg, e + } + copy(msg[off:off+len(b32)], b32) + off += len(b32) + case `dns:"size-hex"`: + fallthrough + case `dns:"hex"`: + // There is no length encoded here + h, e := hex.DecodeString(s) + if e != nil { + return lenmsg, e + } + if off+hex.DecodedLen(len(s)) > lenmsg { + return lenmsg, &Error{err: "overflow packing hex"} + } + copy(msg[off:off+hex.DecodedLen(len(s))], h) + off += hex.DecodedLen(len(s)) + case `dns:"size"`: + // the size is already encoded in the RR, we can safely use the + // length of string. String is RAW (not encoded in hex, nor base64) + copy(msg[off:off+len(s)], s) + off += len(s) + case `dns:"txt"`: + fallthrough + case "": + if txtTmp == nil { + txtTmp = make([]byte, 256*4+1) + } + off, err = packTxtString(fv.String(), msg, off, txtTmp) + if err != nil { + return lenmsg, err + } + } + } + } + return off, nil +} + +func structValue(any interface{}) reflect.Value { + return reflect.ValueOf(any).Elem() +} + +// PackStruct packs any structure to wire format. +func PackStruct(any interface{}, msg []byte, off int) (off1 int, err error) { + off, err = packStructValue(structValue(any), msg, off, nil, false) + return off, err +} + +func packStructCompress(any interface{}, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) { + off, err = packStructValue(structValue(any), msg, off, compression, compress) + return off, err +} + +// TODO(miek): Fix use of rdlength here + +// Unpack a reflect.StructValue from msg. +// Same restrictions as packStructValue. +func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err error) { + var rdend int + lenmsg := len(msg) + for i := 0; i < val.NumField(); i++ { + if off > lenmsg { + return lenmsg, &Error{"bad offset unpacking"} + } + switch fv := val.Field(i); fv.Kind() { + default: + return lenmsg, &Error{err: "bad kind unpacking"} + case reflect.Interface: + // PrivateRR is the only RR implementation that has interface field. 
+ // therefore it's expected that this interface would be PrivateRdata + switch data := fv.Interface().(type) { + case PrivateRdata: + n, err := data.Unpack(msg[off:rdend]) + if err != nil { + return lenmsg, err + } + off += n + default: + return lenmsg, &Error{err: "bad kind interface unpacking"} + } + case reflect.Slice: + switch val.Type().Field(i).Tag { + default: + return lenmsg, &Error{"bad tag unpacking slice: " + val.Type().Field(i).Tag.Get("dns")} + case `dns:"domain-name"`: + // HIP record slice of name (or none) + servers := make([]string, 0) + var s string + for off < rdend { + s, off, err = UnpackDomainName(msg, off) + if err != nil { + return lenmsg, err + } + servers = append(servers, s) + } + fv.Set(reflect.ValueOf(servers)) + case `dns:"txt"`: + if off == lenmsg || rdend == off { + break + } + var txt []string + txt, off, err = unpackTxt(msg, off, rdend) + if err != nil { + return lenmsg, err + } + fv.Set(reflect.ValueOf(txt)) + case `dns:"opt"`: // edns0 + if off == rdend { + // This is an EDNS0 (OPT Record) with no rdata + // We can safely return here. + break + } + edns := make([]EDNS0, 0) + Option: + code := uint16(0) + if off+2 > lenmsg { + return lenmsg, &Error{err: "overflow unpacking opt"} + } + code, off = unpackUint16(msg, off) + optlen, off1 := unpackUint16(msg, off) + if off1+int(optlen) > rdend { + return lenmsg, &Error{err: "overflow unpacking opt"} + } + switch code { + case EDNS0NSID: + e := new(EDNS0_NSID) + if err := e.unpack(msg[off1 : off1+int(optlen)]); err != nil { + return lenmsg, err + } + edns = append(edns, e) + off = off1 + int(optlen) + case EDNS0SUBNET, EDNS0SUBNETDRAFT: + e := new(EDNS0_SUBNET) + if err := e.unpack(msg[off1 : off1+int(optlen)]); err != nil { + return lenmsg, err + } + edns = append(edns, e) + off = off1 + int(optlen) + if code == EDNS0SUBNETDRAFT { + e.DraftOption = true + } + case EDNS0UL: + e := new(EDNS0_UL) + if err := e.unpack(msg[off1 : off1+int(optlen)]); err != nil { + return lenmsg, err + } + edns = append(edns, e) + off = off1 + int(optlen) + case EDNS0LLQ: + e := new(EDNS0_LLQ) + if err := e.unpack(msg[off1 : off1+int(optlen)]); err != nil { + return lenmsg, err + } + edns = append(edns, e) + off = off1 + int(optlen) + case EDNS0DAU: + e := new(EDNS0_DAU) + if err := e.unpack(msg[off1 : off1+int(optlen)]); err != nil { + return lenmsg, err + } + edns = append(edns, e) + off = off1 + int(optlen) + case EDNS0DHU: + e := new(EDNS0_DHU) + if err := e.unpack(msg[off1 : off1+int(optlen)]); err != nil { + return lenmsg, err + } + edns = append(edns, e) + off = off1 + int(optlen) + case EDNS0N3U: + e := new(EDNS0_N3U) + if err := e.unpack(msg[off1 : off1+int(optlen)]); err != nil { + return lenmsg, err + } + edns = append(edns, e) + off = off1 + int(optlen) + default: + // do nothing? + off = off1 + int(optlen) + } + if off < rdend { + goto Option + } + fv.Set(reflect.ValueOf(edns)) + case `dns:"a"`: + if off == lenmsg { + break // dyn. 
update
+			}
+			if off+net.IPv4len > rdend {
+				return lenmsg, &Error{err: "overflow unpacking a"}
+			}
+			fv.Set(reflect.ValueOf(net.IPv4(msg[off], msg[off+1], msg[off+2], msg[off+3])))
+			off += net.IPv4len
+		case `dns:"aaaa"`:
+			if off == rdend {
+				break
+			}
+			if off+net.IPv6len > rdend || off+net.IPv6len > lenmsg {
+				return lenmsg, &Error{err: "overflow unpacking aaaa"}
+			}
+			fv.Set(reflect.ValueOf(net.IP{msg[off], msg[off+1], msg[off+2], msg[off+3], msg[off+4],
+				msg[off+5], msg[off+6], msg[off+7], msg[off+8], msg[off+9], msg[off+10],
+				msg[off+11], msg[off+12], msg[off+13], msg[off+14], msg[off+15]}))
+			off += net.IPv6len
+		case `dns:"wks"`:
+			// Rest of the record is the bitmap
+			serv := make([]uint16, 0)
+			j := 0
+			for off < rdend {
+				if off+1 > lenmsg {
+					return lenmsg, &Error{err: "overflow unpacking wks"}
+				}
+				b := msg[off]
+				// Check the bits one by one, and set the type
+				if b&0x80 == 0x80 {
+					serv = append(serv, uint16(j*8+0))
+				}
+				if b&0x40 == 0x40 {
+					serv = append(serv, uint16(j*8+1))
+				}
+				if b&0x20 == 0x20 {
+					serv = append(serv, uint16(j*8+2))
+				}
+				if b&0x10 == 0x10 {
+					serv = append(serv, uint16(j*8+3))
+				}
+				if b&0x8 == 0x8 {
+					serv = append(serv, uint16(j*8+4))
+				}
+				if b&0x4 == 0x4 {
+					serv = append(serv, uint16(j*8+5))
+				}
+				if b&0x2 == 0x2 {
+					serv = append(serv, uint16(j*8+6))
+				}
+				if b&0x1 == 0x1 {
+					serv = append(serv, uint16(j*8+7))
+				}
+				j++
+				off++
+			}
+			fv.Set(reflect.ValueOf(serv))
+		case `dns:"nsec"`: // NSEC/NSEC3
+			if off == rdend {
+				break
+			}
+			// Rest of the record is the type bitmap
+			if off+2 > rdend || off+2 > lenmsg {
+				return lenmsg, &Error{err: "overflow unpacking nsecx"}
+			}
+			nsec := make([]uint16, 0)
+			length := 0
+			window := 0
+			for off+2 < rdend {
+				window = int(msg[off])
+				length = int(msg[off+1])
+				//println("off, window, length, end", off, window, length, rdend)
+				if length == 0 {
+					// A bitmap length of zero is strange; in that case
+					// the window should not have been specified. Bail out.
+					// println("dns: length == 0 when unpacking NSEC")
+					return lenmsg, &Error{err: "overflow unpacking nsecx"}
+				}
+				if length > 32 {
+					return lenmsg, &Error{err: "overflow unpacking nsecx"}
+				}
+
+				// Walk the bytes in the window and check the bit settings...
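+				// Editor's illustration: type A (= 1) lives in window 0,
+				// octet 0, bit 1, i.e. mask 0x40 of the first bitmap octet.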
+ off += 2 + for j := 0; j < length; j++ { + if off+j+1 > lenmsg { + return lenmsg, &Error{err: "overflow unpacking nsecx"} + } + b := msg[off+j] + // Check the bits one by one, and set the type + if b&0x80 == 0x80 { + nsec = append(nsec, uint16(window*256+j*8+0)) + } + if b&0x40 == 0x40 { + nsec = append(nsec, uint16(window*256+j*8+1)) + } + if b&0x20 == 0x20 { + nsec = append(nsec, uint16(window*256+j*8+2)) + } + if b&0x10 == 0x10 { + nsec = append(nsec, uint16(window*256+j*8+3)) + } + if b&0x8 == 0x8 { + nsec = append(nsec, uint16(window*256+j*8+4)) + } + if b&0x4 == 0x4 { + nsec = append(nsec, uint16(window*256+j*8+5)) + } + if b&0x2 == 0x2 { + nsec = append(nsec, uint16(window*256+j*8+6)) + } + if b&0x1 == 0x1 { + nsec = append(nsec, uint16(window*256+j*8+7)) + } + } + off += length + } + fv.Set(reflect.ValueOf(nsec)) + } + case reflect.Struct: + off, err = unpackStructValue(fv, msg, off) + if err != nil { + return lenmsg, err + } + if val.Type().Field(i).Name == "Hdr" { + rdend = off + int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) + } + case reflect.Uint8: + if off == lenmsg { + break + } + if off+1 > lenmsg { + return lenmsg, &Error{err: "overflow unpacking uint8"} + } + fv.SetUint(uint64(uint8(msg[off]))) + off++ + case reflect.Uint16: + if off == lenmsg { + break + } + var i uint16 + if off+2 > lenmsg { + return lenmsg, &Error{err: "overflow unpacking uint16"} + } + i, off = unpackUint16(msg, off) + fv.SetUint(uint64(i)) + case reflect.Uint32: + if off == lenmsg { + break + } + if off+4 > lenmsg { + return lenmsg, &Error{err: "overflow unpacking uint32"} + } + fv.SetUint(uint64(uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3]))) + off += 4 + case reflect.Uint64: + switch val.Type().Field(i).Tag { + default: + if off+8 > lenmsg { + return lenmsg, &Error{err: "overflow unpacking uint64"} + } + fv.SetUint(uint64(uint64(msg[off])<<56 | uint64(msg[off+1])<<48 | uint64(msg[off+2])<<40 | + uint64(msg[off+3])<<32 | uint64(msg[off+4])<<24 | uint64(msg[off+5])<<16 | uint64(msg[off+6])<<8 | uint64(msg[off+7]))) + off += 8 + case `dns:"uint48"`: + // Used in TSIG where the last 48 bits are occupied, so for now, assume a uint48 (6 bytes) + if off+6 > lenmsg { + return lenmsg, &Error{err: "overflow unpacking uint64 as uint48"} + } + fv.SetUint(uint64(uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 | + uint64(msg[off+4])<<8 | uint64(msg[off+5]))) + off += 6 + } + case reflect.String: + var s string + if off == lenmsg { + break + } + switch val.Type().Field(i).Tag { + default: + return lenmsg, &Error{"bad tag unpacking string: " + val.Type().Field(i).Tag.Get("dns")} + case `dns:"hex"`: + hexend := rdend + if val.FieldByName("Hdr").FieldByName("Rrtype").Uint() == uint64(TypeHIP) { + hexend = off + int(val.FieldByName("HitLength").Uint()) + } + if hexend > rdend || hexend > lenmsg { + return lenmsg, &Error{err: "overflow unpacking hex"} + } + s = hex.EncodeToString(msg[off:hexend]) + off = hexend + case `dns:"base64"`: + // Rest of the RR is base64 encoded value + b64end := rdend + if val.FieldByName("Hdr").FieldByName("Rrtype").Uint() == uint64(TypeHIP) { + b64end = off + int(val.FieldByName("PublicKeyLength").Uint()) + } + if b64end > rdend || b64end > lenmsg { + return lenmsg, &Error{err: "overflow unpacking base64"} + } + s = unpackBase64(msg[off:b64end]) + off = b64end + case `dns:"cdomain-name"`: + fallthrough + case `dns:"domain-name"`: + if off == lenmsg { + // zero rdata foo, OK for dyn. 
updates + break + } + s, off, err = UnpackDomainName(msg, off) + if err != nil { + return lenmsg, err + } + case `dns:"size-base32"`: + var size int + switch val.Type().Name() { + case "NSEC3": + switch val.Type().Field(i).Name { + case "NextDomain": + name := val.FieldByName("HashLength") + size = int(name.Uint()) + } + } + if off+size > lenmsg { + return lenmsg, &Error{err: "overflow unpacking base32"} + } + s = unpackBase32(msg[off : off+size]) + off += size + case `dns:"size-hex"`: + // a "size" string, but it must be encoded in hex in the string + var size int + switch val.Type().Name() { + case "NSEC3": + switch val.Type().Field(i).Name { + case "Salt": + name := val.FieldByName("SaltLength") + size = int(name.Uint()) + case "NextDomain": + name := val.FieldByName("HashLength") + size = int(name.Uint()) + } + case "TSIG": + switch val.Type().Field(i).Name { + case "MAC": + name := val.FieldByName("MACSize") + size = int(name.Uint()) + case "OtherData": + name := val.FieldByName("OtherLen") + size = int(name.Uint()) + } + } + if off+size > lenmsg { + return lenmsg, &Error{err: "overflow unpacking hex"} + } + s = hex.EncodeToString(msg[off : off+size]) + off += size + case `dns:"txt"`: + fallthrough + case "": + s, off, err = unpackTxtString(msg, off) + } + fv.SetString(s) + } + } + return off, nil +} + +// Helpers for dealing with escaped bytes +func isDigit(b byte) bool { return b >= '0' && b <= '9' } + +func dddToByte(s []byte) byte { + return byte((s[0]-'0')*100 + (s[1]-'0')*10 + (s[2] - '0')) +} + +// Helper function for unpacking +func unpackUint16(msg []byte, off int) (v uint16, off1 int) { + v = uint16(msg[off])<<8 | uint16(msg[off+1]) + off1 = off + 2 + return +} + +// UnpackStruct unpacks a binary message from offset off to the interface +// value given. +func UnpackStruct(any interface{}, msg []byte, off int) (off1 int, err error) { + off, err = unpackStructValue(structValue(any), msg, off) + return off, err +} + +func unpackBase32(b []byte) string { + b32 := make([]byte, base32.HexEncoding.EncodedLen(len(b))) + base32.HexEncoding.Encode(b32, b) + return string(b32) +} + +func unpackBase64(b []byte) string { + b64 := make([]byte, base64.StdEncoding.EncodedLen(len(b))) + base64.StdEncoding.Encode(b64, b) + return string(b64) +} + +// Helper function for packing +func packUint16(i uint16) (byte, byte) { + return byte(i >> 8), byte(i) +} + +func packBase64(s []byte) ([]byte, error) { + b64len := base64.StdEncoding.DecodedLen(len(s)) + buf := make([]byte, b64len) + n, err := base64.StdEncoding.Decode(buf, []byte(s)) + if err != nil { + return nil, err + } + buf = buf[:n] + return buf, nil +} + +// Helper function for packing, mostly used in dnssec.go +func packBase32(s []byte) ([]byte, error) { + b32len := base32.HexEncoding.DecodedLen(len(s)) + buf := make([]byte, b32len) + n, err := base32.HexEncoding.Decode(buf, []byte(s)) + if err != nil { + return nil, err + } + buf = buf[:n] + return buf, nil +} + +// PackRR packs a resource record rr into msg[off:]. +// See PackDomainName for documentation about the compression. +func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) { + if rr == nil { + return len(msg), &Error{err: "nil rr"} + } + + off1, err = packStructCompress(rr, msg, off, compression, compress) + if err != nil { + return len(msg), err + } + if rawSetRdlength(msg, off, off1) { + return off1, nil + } + return off, ErrRdata +} + +// UnpackRR unpacks msg[off:] into an RR. 
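+//
+// A minimal pack/unpack round trip (editor's sketch, using only the
+// package-level API defined in this file):
+//
+//	rr, _ := NewRR("example.org. 3600 IN A 192.0.2.1")
+//	buf := make([]byte, 64)
+//	off, _ := PackRR(rr, buf, 0, nil, false)
+//	rr2, _, _ := UnpackRR(buf[:off], 0)
+//	// rr2.String() == rr.String()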
+func UnpackRR(msg []byte, off int) (rr RR, off1 int, err error) { + // unpack just the header, to find the rr type and length + var h RR_Header + off0 := off + if off, err = UnpackStruct(&h, msg, off); err != nil { + return nil, len(msg), err + } + end := off + int(h.Rdlength) + // make an rr of that type and re-unpack. + mk, known := typeToRR[h.Rrtype] + if !known { + rr = new(RFC3597) + } else { + rr = mk() + } + off, err = UnpackStruct(rr, msg, off0) + if off != end { + return &h, end, &Error{err: "bad rdlength"} + } + return rr, off, err +} + +// Reverse a map +func reverseInt8(m map[uint8]string) map[string]uint8 { + n := make(map[string]uint8) + for u, s := range m { + n[s] = u + } + return n +} + +func reverseInt16(m map[uint16]string) map[string]uint16 { + n := make(map[string]uint16) + for u, s := range m { + n[s] = u + } + return n +} + +func reverseInt(m map[int]string) map[string]int { + n := make(map[string]int) + for u, s := range m { + n[s] = u + } + return n +} + +// Convert a MsgHdr to a string, with dig-like headers: +// +//;; opcode: QUERY, status: NOERROR, id: 48404 +// +//;; flags: qr aa rd ra; +func (h *MsgHdr) String() string { + if h == nil { + return " MsgHdr" + } + + s := ";; opcode: " + OpcodeToString[h.Opcode] + s += ", status: " + RcodeToString[h.Rcode] + s += ", id: " + strconv.Itoa(int(h.Id)) + "\n" + + s += ";; flags:" + if h.Response { + s += " qr" + } + if h.Authoritative { + s += " aa" + } + if h.Truncated { + s += " tc" + } + if h.RecursionDesired { + s += " rd" + } + if h.RecursionAvailable { + s += " ra" + } + if h.Zero { // Hmm + s += " z" + } + if h.AuthenticatedData { + s += " ad" + } + if h.CheckingDisabled { + s += " cd" + } + + s += ";" + return s +} + +// Pack packs a Msg: it is converted to to wire format. +// If the dns.Compress is true the message will be in compressed wire format. +func (dns *Msg) Pack() (msg []byte, err error) { + return dns.PackBuffer(nil) +} + +// PackBuffer packs a Msg, using the given buffer buf. If buf is too small +// a new buffer is allocated. +func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) { + var dh Header + var compression map[string]int + if dns.Compress { + compression = make(map[string]int) // Compression pointer mappings + } + + if dns.Rcode < 0 || dns.Rcode > 0xFFF { + return nil, ErrRcode + } + if dns.Rcode > 0xF { + // Regular RCODE field is 4 bits + opt := dns.IsEdns0() + if opt == nil { + return nil, ErrExtendedRcode + } + opt.SetExtendedRcode(uint8(dns.Rcode >> 4)) + dns.Rcode &= 0xF + } + + // Convert convenient Msg into wire-like Header. + dh.Id = dns.Id + dh.Bits = uint16(dns.Opcode)<<11 | uint16(dns.Rcode) + if dns.Response { + dh.Bits |= _QR + } + if dns.Authoritative { + dh.Bits |= _AA + } + if dns.Truncated { + dh.Bits |= _TC + } + if dns.RecursionDesired { + dh.Bits |= _RD + } + if dns.RecursionAvailable { + dh.Bits |= _RA + } + if dns.Zero { + dh.Bits |= _Z + } + if dns.AuthenticatedData { + dh.Bits |= _AD + } + if dns.CheckingDisabled { + dh.Bits |= _CD + } + + // Prepare variable sized arrays. + question := dns.Question + answer := dns.Answer + ns := dns.Ns + extra := dns.Extra + + dh.Qdcount = uint16(len(question)) + dh.Ancount = uint16(len(answer)) + dh.Nscount = uint16(len(ns)) + dh.Arcount = uint16(len(extra)) + + // We need the uncompressed length here, because we first pack it and then compress it. 
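+	// (Editor's note: the uncompressed Len() is an upper bound on the
+	// compressed size, so the buffer allocated below may be larger than
+	// the final message; the result is trimmed to msg[:off] on return.)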
+ msg = buf + compress := dns.Compress + dns.Compress = false + if packLen := dns.Len() + 1; len(msg) < packLen { + msg = make([]byte, packLen) + } + dns.Compress = compress + + // Pack it in: header and then the pieces. + off := 0 + off, err = packStructCompress(&dh, msg, off, compression, dns.Compress) + if err != nil { + return nil, err + } + for i := 0; i < len(question); i++ { + off, err = packStructCompress(&question[i], msg, off, compression, dns.Compress) + if err != nil { + return nil, err + } + } + for i := 0; i < len(answer); i++ { + off, err = PackRR(answer[i], msg, off, compression, dns.Compress) + if err != nil { + return nil, err + } + } + for i := 0; i < len(ns); i++ { + off, err = PackRR(ns[i], msg, off, compression, dns.Compress) + if err != nil { + return nil, err + } + } + for i := 0; i < len(extra); i++ { + off, err = PackRR(extra[i], msg, off, compression, dns.Compress) + if err != nil { + return nil, err + } + } + return msg[:off], nil +} + +// Unpack unpacks a binary message to a Msg structure. +func (dns *Msg) Unpack(msg []byte) (err error) { + // Header. + var dh Header + off := 0 + if off, err = UnpackStruct(&dh, msg, off); err != nil { + return err + } + dns.Id = dh.Id + dns.Response = (dh.Bits & _QR) != 0 + dns.Opcode = int(dh.Bits>>11) & 0xF + dns.Authoritative = (dh.Bits & _AA) != 0 + dns.Truncated = (dh.Bits & _TC) != 0 + dns.RecursionDesired = (dh.Bits & _RD) != 0 + dns.RecursionAvailable = (dh.Bits & _RA) != 0 + dns.Zero = (dh.Bits & _Z) != 0 + dns.AuthenticatedData = (dh.Bits & _AD) != 0 + dns.CheckingDisabled = (dh.Bits & _CD) != 0 + dns.Rcode = int(dh.Bits & 0xF) + + // Arrays. + dns.Question = make([]Question, dh.Qdcount) + dns.Answer = make([]RR, dh.Ancount) + dns.Ns = make([]RR, dh.Nscount) + dns.Extra = make([]RR, dh.Arcount) + + for i := 0; i < len(dns.Question); i++ { + off, err = UnpackStruct(&dns.Question[i], msg, off) + if err != nil { + return err + } + } + // If we see a TC bit being set we return here, without + // an error, because technically it isn't an error. So return + // without parsing the potentially corrupt packet and hitting an error. + // TODO(miek): this isn't the best strategy! + if dns.Truncated { + dns.Answer = nil + dns.Ns = nil + dns.Extra = nil + return nil + } + for i := 0; i < len(dns.Answer); i++ { + dns.Answer[i], off, err = UnpackRR(msg, off) + if err != nil { + return err + } + } + for i := 0; i < len(dns.Ns); i++ { + dns.Ns[i], off, err = UnpackRR(msg, off) + if err != nil { + return err + } + } + for i := 0; i < len(dns.Extra); i++ { + dns.Extra[i], off, err = UnpackRR(msg, off) + if err != nil { + return err + } + } + if off != len(msg) { + // TODO(miek) make this an error? + // use PackOpt to let people tell how detailed the error reporting should be? + // println("dns: extra bytes in dns packet", off, "<", len(msg)) + } + return nil +} + +// Convert a complete message to a string with dig-like output. 
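+// The header portion looks like (editor's sketch, cf. MsgHdr.String above):
+//
+//	;; opcode: QUERY, status: NOERROR, id: 48404
+//	;; flags: qr rd ra; QUERY: 1, ANSWER: 0, AUTHORITY: 0, ADDITIONAL: 0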
+func (dns *Msg) String() string {
+	if dns == nil {
+		return "<nil> MsgHdr"
+	}
+	s := dns.MsgHdr.String() + " "
+	s += "QUERY: " + strconv.Itoa(len(dns.Question)) + ", "
+	s += "ANSWER: " + strconv.Itoa(len(dns.Answer)) + ", "
+	s += "AUTHORITY: " + strconv.Itoa(len(dns.Ns)) + ", "
+	s += "ADDITIONAL: " + strconv.Itoa(len(dns.Extra)) + "\n"
+	if len(dns.Question) > 0 {
+		s += "\n;; QUESTION SECTION:\n"
+		for i := 0; i < len(dns.Question); i++ {
+			s += dns.Question[i].String() + "\n"
+		}
+	}
+	if len(dns.Answer) > 0 {
+		s += "\n;; ANSWER SECTION:\n"
+		for i := 0; i < len(dns.Answer); i++ {
+			if dns.Answer[i] != nil {
+				s += dns.Answer[i].String() + "\n"
+			}
+		}
+	}
+	if len(dns.Ns) > 0 {
+		s += "\n;; AUTHORITY SECTION:\n"
+		for i := 0; i < len(dns.Ns); i++ {
+			if dns.Ns[i] != nil {
+				s += dns.Ns[i].String() + "\n"
+			}
+		}
+	}
+	if len(dns.Extra) > 0 {
+		s += "\n;; ADDITIONAL SECTION:\n"
+		for i := 0; i < len(dns.Extra); i++ {
+			if dns.Extra[i] != nil {
+				s += dns.Extra[i].String() + "\n"
+			}
+		}
+	}
+	return s
+}
+
+// Len returns the message length when in (un)compressed wire format. If
+// dns.Compress is true, compression is taken into account. Len() is
+// provided as a faster way to get the size of the resulting packet than
+// packing the message, measuring its size and discarding the buffer.
+func (dns *Msg) Len() int {
+	// We always return one more than needed.
+	l := 12 // Message header is always 12 bytes
+	var compression map[string]int
+	if dns.Compress {
+		compression = make(map[string]int)
+	}
+	for i := 0; i < len(dns.Question); i++ {
+		l += dns.Question[i].len()
+		if dns.Compress {
+			compressionLenHelper(compression, dns.Question[i].Name)
+		}
+	}
+	for i := 0; i < len(dns.Answer); i++ {
+		l += dns.Answer[i].len()
+		if dns.Compress {
+			k, ok := compressionLenSearch(compression, dns.Answer[i].Header().Name)
+			if ok {
+				l += 1 - k
+			}
+			compressionLenHelper(compression, dns.Answer[i].Header().Name)
+			k, ok = compressionLenSearchType(compression, dns.Answer[i])
+			if ok {
+				l += 1 - k
+			}
+			compressionLenHelperType(compression, dns.Answer[i])
+		}
+	}
+	for i := 0; i < len(dns.Ns); i++ {
+		l += dns.Ns[i].len()
+		if dns.Compress {
+			k, ok := compressionLenSearch(compression, dns.Ns[i].Header().Name)
+			if ok {
+				l += 1 - k
+			}
+			compressionLenHelper(compression, dns.Ns[i].Header().Name)
+			k, ok = compressionLenSearchType(compression, dns.Ns[i])
+			if ok {
+				l += 1 - k
+			}
+			compressionLenHelperType(compression, dns.Ns[i])
+		}
+	}
+	for i := 0; i < len(dns.Extra); i++ {
+		l += dns.Extra[i].len()
+		if dns.Compress {
+			k, ok := compressionLenSearch(compression, dns.Extra[i].Header().Name)
+			if ok {
+				l += 1 - k
+			}
+			compressionLenHelper(compression, dns.Extra[i].Header().Name)
+			k, ok = compressionLenSearchType(compression, dns.Extra[i])
+			if ok {
+				l += 1 - k
+			}
+			compressionLenHelperType(compression, dns.Extra[i])
+		}
+	}
+	return l
+}
+
+// Put the parts of the name in the compression map.
+func compressionLenHelper(c map[string]int, s string) {
+	pref := ""
+	lbs := Split(s)
+	for j := len(lbs) - 1; j >= 0; j-- {
+		pref = s[lbs[j]:]
+		if _, ok := c[pref]; !ok {
+			c[pref] = len(pref)
+		}
+	}
+}
+
+// Look for each part of the name in the compression map and return its
+// length; keep on searching so we get the longest match.
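+// (Editor's illustration: for "www.miek.nl." the map is probed with
+// "www.miek.nl.", then "miek.nl.", then "nl."; the first hit is therefore
+// the longest suffix already seen.)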
+func compressionLenSearch(c map[string]int, s string) (int, bool) {
+	off := 0
+	end := false
+	if s == "" { // don't bork on bogus data
+		return 0, false
+	}
+	for {
+		if _, ok := c[s[off:]]; ok {
+			return len(s[off:]), true
+		}
+		if end {
+			break
+		}
+		off, end = NextLabel(s, off)
+	}
+	return 0, false
+}
+
+// TODO(miek): should add all types, because they can all be *used* for compression.
+func compressionLenHelperType(c map[string]int, r RR) {
+	switch x := r.(type) {
+	case *NS:
+		compressionLenHelper(c, x.Ns)
+	case *MX:
+		compressionLenHelper(c, x.Mx)
+	case *CNAME:
+		compressionLenHelper(c, x.Target)
+	case *PTR:
+		compressionLenHelper(c, x.Ptr)
+	case *SOA:
+		compressionLenHelper(c, x.Ns)
+		compressionLenHelper(c, x.Mbox)
+	case *MB:
+		compressionLenHelper(c, x.Mb)
+	case *MG:
+		compressionLenHelper(c, x.Mg)
+	case *MR:
+		compressionLenHelper(c, x.Mr)
+	case *MF:
+		compressionLenHelper(c, x.Mf)
+	case *MD:
+		compressionLenHelper(c, x.Md)
+	case *RT:
+		compressionLenHelper(c, x.Host)
+	case *MINFO:
+		compressionLenHelper(c, x.Rmail)
+		compressionLenHelper(c, x.Email)
+	case *AFSDB:
+		compressionLenHelper(c, x.Hostname)
+	}
+}
+
+// Only search on compressing these types.
+func compressionLenSearchType(c map[string]int, r RR) (int, bool) {
+	switch x := r.(type) {
+	case *NS:
+		return compressionLenSearch(c, x.Ns)
+	case *MX:
+		return compressionLenSearch(c, x.Mx)
+	case *CNAME:
+		return compressionLenSearch(c, x.Target)
+	case *PTR:
+		return compressionLenSearch(c, x.Ptr)
+	case *SOA:
+		k, ok := compressionLenSearch(c, x.Ns)
+		k1, ok1 := compressionLenSearch(c, x.Mbox)
+		if !ok && !ok1 {
+			return 0, false
+		}
+		return k + k1, true
+	case *MB:
+		return compressionLenSearch(c, x.Mb)
+	case *MG:
+		return compressionLenSearch(c, x.Mg)
+	case *MR:
+		return compressionLenSearch(c, x.Mr)
+	case *MF:
+		return compressionLenSearch(c, x.Mf)
+	case *MD:
+		return compressionLenSearch(c, x.Md)
+	case *RT:
+		return compressionLenSearch(c, x.Host)
+	case *MINFO:
+		k, ok := compressionLenSearch(c, x.Rmail)
+		k1, ok1 := compressionLenSearch(c, x.Email)
+		if !ok && !ok1 {
+			return 0, false
+		}
+		return k + k1, true
+	case *AFSDB:
+		return compressionLenSearch(c, x.Hostname)
+	}
+	return 0, false
+}
+
+// id returns a 16-bit random number to be used as a
+// message id. The randomness provided should be good enough.
+func id() uint16 {
+	return uint16(rand.Int()) ^ uint16(time.Now().Nanosecond())
+}
+
+// Copy returns a new RR which is a deep-copy of r.
+func Copy(r RR) RR {
+	r1 := r.copy()
+	return r1
+}
+
+// Copy returns a new *Msg which is a deep-copy of dns.
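+// Question values are shared rather than re-copied; as the comment in the
+// body notes, Question is treated as an immutable value, so the shallow
+// copy is deliberate.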
+func (dns *Msg) Copy() *Msg {
+	r1 := new(Msg)
+	r1.MsgHdr = dns.MsgHdr
+	r1.Compress = dns.Compress
+
+	if len(dns.Question) > 0 {
+		r1.Question = make([]Question, len(dns.Question))
+		copy(r1.Question, dns.Question) // TODO(miek): Question is an immutable value, ok to do a shallow-copy
+	}
+
+	if len(dns.Answer) > 0 {
+		r1.Answer = make([]RR, len(dns.Answer))
+		for i := 0; i < len(dns.Answer); i++ {
+			r1.Answer[i] = dns.Answer[i].copy()
+		}
+	}
+
+	if len(dns.Ns) > 0 {
+		r1.Ns = make([]RR, len(dns.Ns))
+		for i := 0; i < len(dns.Ns); i++ {
+			r1.Ns[i] = dns.Ns[i].copy()
+		}
+	}
+
+	if len(dns.Extra) > 0 {
+		r1.Extra = make([]RR, len(dns.Extra))
+		for i := 0; i < len(dns.Extra); i++ {
+			r1.Extra[i] = dns.Extra[i].copy()
+		}
+	}
+
+	return r1
+}
diff --git a/Godeps/_workspace/src/github.com/miekg/dns/nsecx.go b/Godeps/_workspace/src/github.com/miekg/dns/nsecx.go
new file mode 100644
index 0000000000..28bf08d9be
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/miekg/dns/nsecx.go
@@ -0,0 +1,110 @@
+package dns
+
+import (
+	"crypto/sha1"
+	"hash"
+	"io"
+	"strings"
+)
+
+type saltWireFmt struct {
+	Salt string `dns:"size-hex"`
+}
+
+// HashName hashes a string (label) according to RFC 5155. It returns the hashed string in
+// uppercase.
+func HashName(label string, ha uint8, iter uint16, salt string) string {
+	saltwire := new(saltWireFmt)
+	saltwire.Salt = salt
+	wire := make([]byte, DefaultMsgSize)
+	n, err := PackStruct(saltwire, wire, 0)
+	if err != nil {
+		return ""
+	}
+	wire = wire[:n]
+	name := make([]byte, 255)
+	off, err := PackDomainName(strings.ToLower(label), name, 0, nil, false)
+	if err != nil {
+		return ""
+	}
+	name = name[:off]
+	var s hash.Hash
+	switch ha {
+	case SHA1:
+		s = sha1.New()
+	default:
+		return ""
+	}
+
+	// k = 0
+	name = append(name, wire...)
+	io.WriteString(s, string(name))
+	nsec3 := s.Sum(nil)
+	// k > 0
+	for k := uint16(0); k < iter; k++ {
+		s.Reset()
+		nsec3 = append(nsec3, wire...)
+		io.WriteString(s, string(nsec3))
+		nsec3 = s.Sum(nil)
+	}
+	return unpackBase32(nsec3)
+}
+
+type Denialer interface {
+	// Cover will check if the (unhashed) name is being covered by this NSEC or NSEC3.
+	Cover(name string) bool
+	// Match will check if the owner name matches the (unhashed) name for this NSEC or NSEC3.
+	Match(name string) bool
+}
+
+// Cover implements the Denialer interface.
+func (rr *NSEC) Cover(name string) bool {
+	return true
+}
+
+// Match implements the Denialer interface.
+func (rr *NSEC) Match(name string) bool {
+	return true
+}
+
+// Cover implements the Denialer interface.
+func (rr *NSEC3) Cover(name string) bool {
+	// FIXME(miek): check if the zones match
+	// FIXME(miek): check if we're not dealing with a parent nsec3
+	hname := HashName(name, rr.Hash, rr.Iterations, rr.Salt)
+	labels := Split(rr.Hdr.Name)
+	if len(labels) < 2 {
+		return false
+	}
+	hash := strings.ToUpper(rr.Hdr.Name[labels[0] : labels[1]-1]) // -1 to remove the dot
+	if hash == rr.NextDomain {
+		return false // empty interval
+	}
+	if hash > rr.NextDomain { // last name, points to apex
+		// hname > hash
+		// hname > rr.NextDomain
+		// TODO(miek)
+	}
+	if hname <= hash {
+		return false
+	}
+	if hname >= rr.NextDomain {
+		return false
+	}
+	return true
+}
+
+// Match implements the Denialer interface.
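+// Match hashes name with the RR's own parameters; a concrete HashName
+// example (taken from this package's tests):
+//
+//	HashName("dnsex.nl.", SHA1, 0, "DEAD")
+//	// == "ROCCJAE8BJJU7HN6T7NG3TNM8ACRS87J"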
+func (rr *NSEC3) Match(name string) bool { + // FIXME(miek): Check if we are in the same zone + hname := HashName(name, rr.Hash, rr.Iterations, rr.Salt) + labels := Split(rr.Hdr.Name) + if len(labels) < 2 { + return false + } + hash := strings.ToUpper(rr.Hdr.Name[labels[0] : labels[1]-1]) // -1 to remove the . + if hash == hname { + return true + } + return false +} diff --git a/Godeps/_workspace/src/github.com/miekg/dns/nsecx_test.go b/Godeps/_workspace/src/github.com/miekg/dns/nsecx_test.go new file mode 100644 index 0000000000..72f641a667 --- /dev/null +++ b/Godeps/_workspace/src/github.com/miekg/dns/nsecx_test.go @@ -0,0 +1,33 @@ +package dns + +import ( + "testing" +) + +func TestPackNsec3(t *testing.T) { + nsec3 := HashName("dnsex.nl.", SHA1, 0, "DEAD") + if nsec3 != "ROCCJAE8BJJU7HN6T7NG3TNM8ACRS87J" { + t.Logf("%v\n", nsec3) + t.Fail() + } + + nsec3 = HashName("a.b.c.example.org.", SHA1, 2, "DEAD") + if nsec3 != "6LQ07OAHBTOOEU2R9ANI2AT70K5O0RCG" { + t.Logf("%v\n", nsec3) + t.Fail() + } +} + +func TestNsec3(t *testing.T) { + // examples taken from .nl + nsec3, _ := NewRR("39p91242oslggest5e6a7cci4iaeqvnk.nl. IN NSEC3 1 1 5 F10E9F7EA83FC8F3 39P99DCGG0MDLARTCRMCF6OFLLUL7PR6 NS DS RRSIG") + if !nsec3.(*NSEC3).Cover("snasajsksasasa.nl.") { // 39p94jrinub66hnpem8qdpstrec86pg3 + t.Logf("39p94jrinub66hnpem8qdpstrec86pg3. should be covered by 39p91242oslggest5e6a7cci4iaeqvnk.nl. - 39P99DCGG0MDLARTCRMCF6OFLLUL7PR6") + t.Fail() + } + nsec3, _ = NewRR("sk4e8fj94u78smusb40o1n0oltbblu2r.nl. IN NSEC3 1 1 5 F10E9F7EA83FC8F3 SK4F38CQ0ATIEI8MH3RGD0P5I4II6QAN NS SOA TXT RRSIG DNSKEY NSEC3PARAM") + if !nsec3.(*NSEC3).Match("nl.") { // sk4e8fj94u78smusb40o1n0oltbblu2r.nl. + t.Logf("sk4e8fj94u78smusb40o1n0oltbblu2r.nl. should match sk4e8fj94u78smusb40o1n0oltbblu2r.nl.") + t.Fail() + } +} diff --git a/Godeps/_workspace/src/github.com/miekg/dns/parse_test.go b/Godeps/_workspace/src/github.com/miekg/dns/parse_test.go new file mode 100644 index 0000000000..3c18b38d69 --- /dev/null +++ b/Godeps/_workspace/src/github.com/miekg/dns/parse_test.go @@ -0,0 +1,1287 @@ +package dns + +import ( + "bytes" + "crypto/rsa" + "encoding/hex" + "fmt" + "math/rand" + "net" + "reflect" + "strconv" + "strings" + "testing" + "testing/quick" + "time" +) + +func TestDotInName(t *testing.T) { + buf := make([]byte, 20) + PackDomainName("aa\\.bb.nl.", buf, 0, nil, false) + // index 3 must be a real dot + if buf[3] != '.' { + t.Log("dot should be a real dot") + t.Fail() + } + + if buf[6] != 2 { + t.Log("this must have the value 2") + t.Fail() + } + dom, _, _ := UnpackDomainName(buf, 0) + // printing it should yield the backspace again + if dom != "aa\\.bb.nl." { + t.Log("dot should have been escaped: " + dom) + t.Fail() + } +} + +func TestDotLastInLabel(t *testing.T) { + sample := "aa\\..au." + buf := make([]byte, 20) + _, err := PackDomainName(sample, buf, 0, nil, false) + if err != nil { + t.Fatalf("unexpected error packing domain: %s", err) + } + dom, _, _ := UnpackDomainName(buf, 0) + if dom != sample { + t.Fatalf("unpacked domain `%s' doesn't match packed domain", dom) + } +} + +func TestTooLongDomainName(t *testing.T) { + l := "aaabbbcccdddeeefffggghhhiiijjjkkklllmmmnnnooopppqqqrrrsssttt." + dom := l + l + l + l + l + l + l + _, e := NewRR(dom + " IN A 127.0.0.1") + if e == nil { + t.Log("should be too long") + t.Fail() + } else { + t.Logf("error is %s", e.Error()) + } + _, e = NewRR("..com. 
IN A 127.0.0.1") + if e == nil { + t.Log("should fail") + t.Fail() + } else { + t.Logf("error is %s", e.Error()) + } +} + +func TestDomainName(t *testing.T) { + tests := []string{"r\\.gieben.miek.nl.", "www\\.www.miek.nl.", + "www.*.miek.nl.", "www.*.miek.nl.", + } + dbuff := make([]byte, 40) + + for _, ts := range tests { + if _, err := PackDomainName(ts, dbuff, 0, nil, false); err != nil { + t.Log("not a valid domain name") + t.Fail() + continue + } + n, _, err := UnpackDomainName(dbuff, 0) + if err != nil { + t.Log("failed to unpack packed domain name") + t.Fail() + continue + } + if ts != n { + t.Logf("must be equal: in: %s, out: %s\n", ts, n) + t.Fail() + } + } +} + +func TestDomainNameAndTXTEscapes(t *testing.T) { + tests := []byte{'.', '(', ')', ';', ' ', '@', '"', '\\', '\t', '\r', '\n', 0, 255} + for _, b := range tests { + rrbytes := []byte{ + 1, b, 0, // owner + byte(TypeTXT >> 8), byte(TypeTXT), + byte(ClassINET >> 8), byte(ClassINET), + 0, 0, 0, 1, // TTL + 0, 2, 1, b, // Data + } + rr1, _, err := UnpackRR(rrbytes, 0) + if err != nil { + panic(err) + } + s := rr1.String() + rr2, err := NewRR(s) + if err != nil { + t.Logf("Error parsing unpacked RR's string: %v", err) + t.Logf(" Bytes: %v\n", rrbytes) + t.Logf("String: %v\n", s) + t.Fail() + } + repacked := make([]byte, len(rrbytes)) + if _, err := PackRR(rr2, repacked, 0, nil, false); err != nil { + t.Logf("error packing parsed RR: %v", err) + t.Logf(" original Bytes: %v\n", rrbytes) + t.Logf("unpacked Struct: %V\n", rr1) + t.Logf(" parsed Struct: %V\n", rr2) + t.Fail() + } + if !bytes.Equal(repacked, rrbytes) { + t.Log("packed bytes don't match original bytes") + t.Logf(" original bytes: %v", rrbytes) + t.Logf(" packed bytes: %v", repacked) + t.Logf("unpacked struct: %V", rr1) + t.Logf(" parsed struct: %V", rr2) + t.Fail() + } + } +} + +func TestTXTEscapeParsing(t *testing.T) { + test := [][]string{ + {`";"`, `";"`}, + {`\;`, `";"`}, + {`"\t"`, `"\t"`}, + {`"\r"`, `"\r"`}, + {`"\ "`, `" "`}, + {`"\;"`, `";"`}, + {`"\;\""`, `";\""`}, + {`"\(a\)"`, `"(a)"`}, + {`"\(a)"`, `"(a)"`}, + {`"(a\)"`, `"(a)"`}, + {`"(a)"`, `"(a)"`}, + {`"\048"`, `"0"`}, + {`"\` + "\n" + `"`, `"\n"`}, + {`"\` + "\r" + `"`, `"\r"`}, + {`"\` + "\x11" + `"`, `"\017"`}, + {`"\'"`, `"'"`}, + } + for _, s := range test { + rr, err := NewRR(fmt.Sprintf("example.com. IN TXT %v", s[0])) + if err != nil { + t.Errorf("Could not parse %v TXT: %s", s[0], err) + continue + } + + txt := sprintTxt(rr.(*TXT).Txt) + if txt != s[1] { + t.Errorf("Mismatch after parsing `%v` TXT record: `%v` != `%v`", s[0], txt, s[1]) + } + } +} + +func GenerateDomain(r *rand.Rand, size int) []byte { + dnLen := size % 70 // artificially limit size so there's less to intrepret if a failure occurs + var dn []byte + done := false + for i := 0; i < dnLen && !done; { + max := dnLen - i + if max > 63 { + max = 63 + } + lLen := max + if lLen != 0 { + lLen = int(r.Int31()) % max + } + done = lLen == 0 + if done { + continue + } + l := make([]byte, lLen+1) + l[0] = byte(lLen) + for j := 0; j < lLen; j++ { + l[j+1] = byte(rand.Int31()) + } + dn = append(dn, l...) 
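+		// (each appended chunk is one wire-format label: a length byte
+		// followed by lLen random octets)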
+ i += 1 + lLen + } + return append(dn, 0) +} + +func TestDomainQuick(t *testing.T) { + r := rand.New(rand.NewSource(0)) + f := func(l int) bool { + db := GenerateDomain(r, l) + ds, _, err := UnpackDomainName(db, 0) + if err != nil { + panic(err) + } + buf := make([]byte, 255) + off, err := PackDomainName(ds, buf, 0, nil, false) + if err != nil { + t.Logf("error packing domain: %s", err.Error()) + t.Logf(" bytes: %v\n", db) + t.Logf("string: %v\n", ds) + return false + } + if !bytes.Equal(db, buf[:off]) { + t.Logf("repacked domain doesn't match original:") + t.Logf("src bytes: %v", db) + t.Logf(" string: %v", ds) + t.Logf("out bytes: %v", buf[:off]) + return false + } + return true + } + if err := quick.Check(f, nil); err != nil { + t.Error(err) + } +} + +func GenerateTXT(r *rand.Rand, size int) []byte { + rdLen := size % 300 // artificially limit size so there's less to intrepret if a failure occurs + var rd []byte + for i := 0; i < rdLen; { + max := rdLen - 1 + if max > 255 { + max = 255 + } + sLen := max + if max != 0 { + sLen = int(r.Int31()) % max + } + s := make([]byte, sLen+1) + s[0] = byte(sLen) + for j := 0; j < sLen; j++ { + s[j+1] = byte(rand.Int31()) + } + rd = append(rd, s...) + i += 1 + sLen + } + return rd +} + +func TestTXTRRQuick(t *testing.T) { + s := rand.NewSource(0) + r := rand.New(s) + typeAndClass := []byte{ + byte(TypeTXT >> 8), byte(TypeTXT), + byte(ClassINET >> 8), byte(ClassINET), + 0, 0, 0, 1, // TTL + } + f := func(l int) bool { + owner := GenerateDomain(r, l) + rdata := GenerateTXT(r, l) + rrbytes := make([]byte, 0, len(owner)+2+2+4+2+len(rdata)) + rrbytes = append(rrbytes, owner...) + rrbytes = append(rrbytes, typeAndClass...) + rrbytes = append(rrbytes, byte(len(rdata)>>8)) + rrbytes = append(rrbytes, byte(len(rdata))) + rrbytes = append(rrbytes, rdata...) 
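+		// rrbytes is now one complete wire-format RR:
+		// owner | type | class | TTL | rdlength | rdata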
+ rr, _, err := UnpackRR(rrbytes, 0) + if err != nil { + panic(err) + } + buf := make([]byte, len(rrbytes)*3) + off, err := PackRR(rr, buf, 0, nil, false) + if err != nil { + t.Logf("pack Error: %s\nRR: %V", err.Error(), rr) + return false + } + buf = buf[:off] + if !bytes.Equal(buf, rrbytes) { + t.Logf("packed bytes don't match original bytes") + t.Logf("src bytes: %v", rrbytes) + t.Logf(" struct: %V", rr) + t.Logf("oUt bytes: %v", buf) + return false + } + if len(rdata) == 0 { + // string'ing won't produce any data to parse + return true + } + rrString := rr.String() + rr2, err := NewRR(rrString) + if err != nil { + t.Logf("error parsing own output: %s", err.Error()) + t.Logf("struct: %V", rr) + t.Logf("string: %v", rrString) + return false + } + if rr2.String() != rrString { + t.Logf("parsed rr.String() doesn't match original string") + t.Logf("original: %v", rrString) + t.Logf(" parsed: %v", rr2.String()) + return false + } + + buf = make([]byte, len(rrbytes)*3) + off, err = PackRR(rr2, buf, 0, nil, false) + if err != nil { + t.Logf("error packing parsed rr: %s", err.Error()) + t.Logf("unpacked Struct: %V", rr) + t.Logf(" string: %v", rrString) + t.Logf(" parsed Struct: %V", rr2) + return false + } + buf = buf[:off] + if !bytes.Equal(buf, rrbytes) { + t.Logf("parsed packed bytes don't match original bytes") + t.Logf(" source bytes: %v", rrbytes) + t.Logf("unpacked struct: %V", rr) + t.Logf(" string: %v", rrString) + t.Logf(" parsed struct: %V", rr2) + t.Logf(" repacked bytes: %v", buf) + return false + } + return true + } + c := &quick.Config{MaxCountScale: 10} + if err := quick.Check(f, c); err != nil { + t.Error(err) + } +} + +func TestParseDirectiveMisc(t *testing.T) { + tests := map[string]string{ + "$ORIGIN miek.nl.\na IN NS b": "a.miek.nl.\t3600\tIN\tNS\tb.miek.nl.", + "$TTL 2H\nmiek.nl. IN NS b.": "miek.nl.\t7200\tIN\tNS\tb.", + "miek.nl. 1D IN NS b.": "miek.nl.\t86400\tIN\tNS\tb.", + `name. IN SOA a6.nstld.com. hostmaster.nic.name. ( + 203362132 ; serial + 5m ; refresh (5 minutes) + 5m ; retry (5 minutes) + 2w ; expire (2 weeks) + 300 ; minimum (5 minutes) +)`: "name.\t3600\tIN\tSOA\ta6.nstld.com. hostmaster.nic.name. 203362132 300 300 1209600 300", + ". 3600000 IN NS ONE.MY-ROOTS.NET.": ".\t3600000\tIN\tNS\tONE.MY-ROOTS.NET.", + "ONE.MY-ROOTS.NET. 3600000 IN A 192.168.1.1": "ONE.MY-ROOTS.NET.\t3600000\tIN\tA\t192.168.1.1", + } + for i, o := range tests { + rr, e := NewRR(i) + if e != nil { + t.Log("failed to parse RR: " + e.Error()) + t.Fail() + continue + } + if rr.String() != o { + t.Logf("`%s' should be equal to\n`%s', but is `%s'\n", i, o, rr.String()) + t.Fail() + } else { + t.Logf("RR is OK: `%s'", rr.String()) + } + } +} + +func TestNSEC(t *testing.T) { + nsectests := map[string]string{ + "nl. IN NSEC3PARAM 1 0 5 30923C44C6CBBB8F": "nl.\t3600\tIN\tNSEC3PARAM\t1 0 5 30923C44C6CBBB8F", + "p2209hipbpnm681knjnu0m1febshlv4e.nl. IN NSEC3 1 1 5 30923C44C6CBBB8F P90DG1KE8QEAN0B01613LHQDG0SOJ0TA NS SOA TXT RRSIG DNSKEY NSEC3PARAM": "p2209hipbpnm681knjnu0m1febshlv4e.nl.\t3600\tIN\tNSEC3\t1 1 5 30923C44C6CBBB8F P90DG1KE8QEAN0B01613LHQDG0SOJ0TA NS SOA TXT RRSIG DNSKEY NSEC3PARAM", + "localhost.dnssex.nl. IN NSEC www.dnssex.nl. A RRSIG NSEC": "localhost.dnssex.nl.\t3600\tIN\tNSEC\twww.dnssex.nl. A RRSIG NSEC", + "localhost.dnssex.nl. IN NSEC www.dnssex.nl. A RRSIG NSEC TYPE65534": "localhost.dnssex.nl.\t3600\tIN\tNSEC\twww.dnssex.nl. A RRSIG NSEC TYPE65534", + "localhost.dnssex.nl. IN NSEC www.dnssex.nl. A RRSIG NSec Type65534": "localhost.dnssex.nl.\t3600\tIN\tNSEC\twww.dnssex.nl. 
A RRSIG NSEC TYPE65534", + } + for i, o := range nsectests { + rr, e := NewRR(i) + if e != nil { + t.Log("failed to parse RR: " + e.Error()) + t.Fail() + continue + } + if rr.String() != o { + t.Logf("`%s' should be equal to\n`%s', but is `%s'\n", i, o, rr.String()) + t.Fail() + } else { + t.Logf("RR is OK: `%s'", rr.String()) + } + } +} + +func TestParseLOC(t *testing.T) { + lt := map[string]string{ + "SW1A2AA.find.me.uk. LOC 51 30 12.748 N 00 07 39.611 W 0.00m 0.00m 0.00m 0.00m": "SW1A2AA.find.me.uk.\t3600\tIN\tLOC\t51 30 12.748 N 00 07 39.611 W 0.00m 0.00m 0.00m 0.00m", + "SW1A2AA.find.me.uk. LOC 51 0 0.0 N 00 07 39.611 W 0.00m 0.00m 0.00m 0.00m": "SW1A2AA.find.me.uk.\t3600\tIN\tLOC\t51 00 0.000 N 00 07 39.611 W 0.00m 0.00m 0.00m 0.00m", + } + for i, o := range lt { + rr, e := NewRR(i) + if e != nil { + t.Log("failed to parse RR: " + e.Error()) + t.Fail() + continue + } + if rr.String() != o { + t.Logf("`%s' should be equal to\n`%s', but is `%s'\n", i, o, rr.String()) + t.Fail() + } else { + t.Logf("RR is OK: `%s'", rr.String()) + } + } +} + +func TestParseDS(t *testing.T) { + dt := map[string]string{ + "example.net. 3600 IN DS 40692 12 3 22261A8B0E0D799183E35E24E2AD6BB58533CBA7E3B14D659E9CA09B 2071398F": "example.net.\t3600\tIN\tDS\t40692 12 3 22261A8B0E0D799183E35E24E2AD6BB58533CBA7E3B14D659E9CA09B2071398F", + } + for i, o := range dt { + rr, e := NewRR(i) + if e != nil { + t.Log("failed to parse RR: " + e.Error()) + t.Fail() + continue + } + if rr.String() != o { + t.Logf("`%s' should be equal to\n`%s', but is `%s'\n", i, o, rr.String()) + t.Fail() + } else { + t.Logf("RR is OK: `%s'", rr.String()) + } + } +} + +func TestQuotes(t *testing.T) { + tests := map[string]string{ + `t.example.com. IN TXT "a bc"`: "t.example.com.\t3600\tIN\tTXT\t\"a bc\"", + `t.example.com. IN TXT "a + bc"`: "t.example.com.\t3600\tIN\tTXT\t\"a\\n bc\"", + `t.example.com. IN TXT ""`: "t.example.com.\t3600\tIN\tTXT\t\"\"", + `t.example.com. IN TXT "a"`: "t.example.com.\t3600\tIN\tTXT\t\"a\"", + `t.example.com. IN TXT "aa"`: "t.example.com.\t3600\tIN\tTXT\t\"aa\"", + `t.example.com. IN TXT "aaa" ;`: "t.example.com.\t3600\tIN\tTXT\t\"aaa\"", + `t.example.com. IN TXT "abc" "DEF"`: "t.example.com.\t3600\tIN\tTXT\t\"abc\" \"DEF\"", + `t.example.com. IN TXT "abc" ( "DEF" )`: "t.example.com.\t3600\tIN\tTXT\t\"abc\" \"DEF\"", + `t.example.com. IN TXT aaa ;`: "t.example.com.\t3600\tIN\tTXT\t\"aaa \"", + `t.example.com. IN TXT aaa aaa;`: "t.example.com.\t3600\tIN\tTXT\t\"aaa aaa\"", + `t.example.com. IN TXT aaa aaa`: "t.example.com.\t3600\tIN\tTXT\t\"aaa aaa\"", + `t.example.com. IN TXT aaa`: "t.example.com.\t3600\tIN\tTXT\t\"aaa\"", + "cid.urn.arpa. NAPTR 100 50 \"s\" \"z3950+I2L+I2C\" \"\" _z3950._tcp.gatech.edu.": "cid.urn.arpa.\t3600\tIN\tNAPTR\t100 50 \"s\" \"z3950+I2L+I2C\" \"\" _z3950._tcp.gatech.edu.", + "cid.urn.arpa. NAPTR 100 50 \"s\" \"rcds+I2C\" \"\" _rcds._udp.gatech.edu.": "cid.urn.arpa.\t3600\tIN\tNAPTR\t100 50 \"s\" \"rcds+I2C\" \"\" _rcds._udp.gatech.edu.", + "cid.urn.arpa. NAPTR 100 50 \"s\" \"http+I2L+I2C+I2R\" \"\" _http._tcp.gatech.edu.": "cid.urn.arpa.\t3600\tIN\tNAPTR\t100 50 \"s\" \"http+I2L+I2C+I2R\" \"\" _http._tcp.gatech.edu.", + "cid.urn.arpa. 
NAPTR 100 10 \"\" \"\" \"/urn:cid:.+@([^\\.]+\\.)(.*)$/\\2/i\" .": "cid.urn.arpa.\t3600\tIN\tNAPTR\t100 10 \"\" \"\" \"/urn:cid:.+@([^\\.]+\\.)(.*)$/\\2/i\" .", + } + for i, o := range tests { + rr, e := NewRR(i) + if e != nil { + t.Log("failed to parse RR: " + e.Error()) + t.Fail() + continue + } + if rr.String() != o { + t.Logf("`%s' should be equal to\n`%s', but is\n`%s'\n", i, o, rr.String()) + t.Fail() + } else { + t.Logf("RR is OK: `%s'", rr.String()) + } + } +} + +func TestParseClass(t *testing.T) { + tests := map[string]string{ + "t.example.com. IN A 127.0.0.1": "t.example.com. 3600 IN A 127.0.0.1", + "t.example.com. CS A 127.0.0.1": "t.example.com. 3600 CS A 127.0.0.1", + "t.example.com. CH A 127.0.0.1": "t.example.com. 3600 CH A 127.0.0.1", + // ClassANY can not occur in zone files + // "t.example.com. ANY A 127.0.0.1": "t.example.com. 3600 ANY A 127.0.0.1", + "t.example.com. NONE A 127.0.0.1": "t.example.com. 3600 NONE A 127.0.0.1", + } + for i, o := range tests { + rr, e := NewRR(i) + if e != nil { + t.Log("failed to parse RR: " + e.Error()) + t.Fail() + continue + } + if rr.String() != o { + t.Logf("`%s' should be equal to\n`%s', but is\n`%s'\n", i, o, rr.String()) + t.Fail() + } else { + t.Logf("RR is OK: `%s'", rr.String()) + } + } +} + +func TestBrace(t *testing.T) { + tests := map[string]string{ + "(miek.nl.) 3600 IN A 127.0.1.1": "miek.nl.\t3600\tIN\tA\t127.0.1.1", + "miek.nl. (3600) IN MX (10) elektron.atoom.net.": "miek.nl.\t3600\tIN\tMX\t10 elektron.atoom.net.", + `miek.nl. IN ( + 3600 A 127.0.0.1)`: "miek.nl.\t3600\tIN\tA\t127.0.0.1", + "(miek.nl.) (A) (127.0.2.1)": "miek.nl.\t3600\tIN\tA\t127.0.2.1", + "miek.nl A 127.0.3.1": "miek.nl.\t3600\tIN\tA\t127.0.3.1", + "_ssh._tcp.local. 60 IN (PTR) stora._ssh._tcp.local.": "_ssh._tcp.local.\t60\tIN\tPTR\tstora._ssh._tcp.local.", + "miek.nl. NS ns.miek.nl": "miek.nl.\t3600\tIN\tNS\tns.miek.nl.", + `(miek.nl.) ( + (IN) + (AAAA) + (::1) )`: "miek.nl.\t3600\tIN\tAAAA\t::1", + `(miek.nl.) ( + (IN) + (AAAA) + (::1))`: "miek.nl.\t3600\tIN\tAAAA\t::1", + "miek.nl. IN AAAA ::2": "miek.nl.\t3600\tIN\tAAAA\t::2", + `((m)(i)ek.(n)l.) (SOA) (soa.) (soa.) ( + 2009032802 ; serial + 21600 ; refresh (6 hours) + 7(2)00 ; retry (2 hours) + 604()800 ; expire (1 week) + 3600 ; minimum (1 hour) + )`: "miek.nl.\t3600\tIN\tSOA\tsoa. soa. 2009032802 21600 7200 604800 3600", + "miek\\.nl. IN A 127.0.0.10": "miek\\.nl.\t3600\tIN\tA\t127.0.0.10", + "miek.nl. IN A 127.0.0.11": "miek.nl.\t3600\tIN\tA\t127.0.0.11", + "miek.nl. A 127.0.0.12": "miek.nl.\t3600\tIN\tA\t127.0.0.12", + `miek.nl. 86400 IN SOA elektron.atoom.net. miekg.atoom.net. ( + 2009032802 ; serial + 21600 ; refresh (6 hours) + 7200 ; retry (2 hours) + 604800 ; expire (1 week) + 3600 ; minimum (1 hour) + )`: "miek.nl.\t86400\tIN\tSOA\telektron.atoom.net. miekg.atoom.net. 2009032802 21600 7200 604800 3600", + } + for i, o := range tests { + rr, e := NewRR(i) + if e != nil { + t.Log("failed to parse RR: " + e.Error() + "\n\t" + i) + t.Fail() + continue + } + if rr.String() != o { + t.Logf("`%s' should be equal to\n`%s', but is `%s'\n", i, o, rr.String()) + t.Fail() + } else { + t.Logf("RR is OK: `%s'", rr.String()) + } + } +} + +func TestParseFailure(t *testing.T) { + tests := []string{"miek.nl. IN A 327.0.0.1", + "miek.nl. IN AAAA ::x", + "miek.nl. IN MX a0 miek.nl.", + "miek.nl aap IN MX mx.miek.nl.", + "miek.nl 200 IN mxx 10 mx.miek.nl.", + "miek.nl. inn MX 10 mx.miek.nl.", + // "miek.nl. IN CNAME ", // actually valid nowadays, zero size rdata + "miek.nl. IN CNAME ..", + "miek.nl. 
PA MX 10 miek.nl.", + "miek.nl. ) IN MX 10 miek.nl.", + } + + for _, s := range tests { + _, err := NewRR(s) + if err == nil { + t.Logf("should have triggered an error: \"%s\"", s) + t.Fail() + } + } +} + +func TestZoneParsing(t *testing.T) { + // parse_test.db + db := ` +a.example.com. IN A 127.0.0.1 +8db7._openpgpkey.example.com. IN OPENPGPKEY mQCNAzIG +$ORIGIN a.example.com. +test IN A 127.0.0.1 +$ORIGIN b.example.com. +test IN CNAME test.a.example.com. +` + start := time.Now().UnixNano() + to := ParseZone(strings.NewReader(db), "", "parse_test.db") + var i int + for x := range to { + i++ + if x.Error != nil { + t.Logf("%s\n", x.Error) + t.Fail() + continue + } + t.Logf("%s\n", x.RR) + } + delta := time.Now().UnixNano() - start + t.Logf("%d RRs parsed in %.2f s (%.2f RR/s)", i, float32(delta)/1e9, float32(i)/(float32(delta)/1e9)) +} + +func ExampleZone() { + zone := `$ORIGIN . +$TTL 3600 ; 1 hour +name IN SOA a6.nstld.com. hostmaster.nic.name. ( + 203362132 ; serial + 300 ; refresh (5 minutes) + 300 ; retry (5 minutes) + 1209600 ; expire (2 weeks) + 300 ; minimum (5 minutes) + ) +$TTL 10800 ; 3 hours +name. 10800 IN NS name. + IN NS g6.nstld.com. + 7200 NS h6.nstld.com. + 3600 IN NS j6.nstld.com. + IN 3600 NS k6.nstld.com. + NS l6.nstld.com. + NS a6.nstld.com. + NS c6.nstld.com. + NS d6.nstld.com. + NS f6.nstld.com. + NS m6.nstld.com. +( + NS m7.nstld.com. +) +$ORIGIN name. +0-0onlus NS ns7.ehiweb.it. + NS ns8.ehiweb.it. +0-g MX 10 mx01.nic + MX 10 mx02.nic + MX 10 mx03.nic + MX 10 mx04.nic +$ORIGIN 0-g.name +moutamassey NS ns01.yahoodomains.jp. + NS ns02.yahoodomains.jp. +` + to := ParseZone(strings.NewReader(zone), "", "testzone") + for x := range to { + fmt.Printf("%s\n", x.RR) + } + // Output: + // name. 3600 IN SOA a6.nstld.com. hostmaster.nic.name. 203362132 300 300 1209600 300 + // name. 10800 IN NS name. + // name. 10800 IN NS g6.nstld.com. + // name. 7200 IN NS h6.nstld.com. + // name. 3600 IN NS j6.nstld.com. + // name. 3600 IN NS k6.nstld.com. + // name. 10800 IN NS l6.nstld.com. + // name. 10800 IN NS a6.nstld.com. + // name. 10800 IN NS c6.nstld.com. + // name. 10800 IN NS d6.nstld.com. + // name. 10800 IN NS f6.nstld.com. + // name. 10800 IN NS m6.nstld.com. + // name. 10800 IN NS m7.nstld.com. + // 0-0onlus.name. 10800 IN NS ns7.ehiweb.it. + // 0-0onlus.name. 10800 IN NS ns8.ehiweb.it. + // 0-g.name. 10800 IN MX 10 mx01.nic.name. + // 0-g.name. 10800 IN MX 10 mx02.nic.name. + // 0-g.name. 10800 IN MX 10 mx03.nic.name. + // 0-g.name. 10800 IN MX 10 mx04.nic.name. + // moutamassey.0-g.name.name. 10800 IN NS ns01.yahoodomains.jp. + // moutamassey.0-g.name.name. 10800 IN NS ns02.yahoodomains.jp. +} + +func ExampleHIP() { + h := `www.example.com IN HIP ( 2 200100107B1A74DF365639CC39F1D578 + AwEAAbdxyhNuSutc5EMzxTs9LBPCIkOFH8cIvM4p +9+LrV4e19WzK00+CI6zBCQTdtWsuxKbWIy87UOoJTwkUs7lBu+Upr1gsNrut79ryra+bSRGQ +b1slImA8YVJyuIDsj7kwzG7jnERNqnWxZ48AWkskmdHaVDP4BcelrTI3rMXdXF5D + rvs.example.com. )` + if hip, err := NewRR(h); err == nil { + fmt.Printf("%s\n", hip.String()) + } + // Output: + // www.example.com. 3600 IN HIP 2 200100107B1A74DF365639CC39F1D578 AwEAAbdxyhNuSutc5EMzxTs9LBPCIkOFH8cIvM4p9+LrV4e19WzK00+CI6zBCQTdtWsuxKbWIy87UOoJTwkUs7lBu+Upr1gsNrut79ryra+bSRGQb1slImA8YVJyuIDsj7kwzG7jnERNqnWxZ48AWkskmdHaVDP4BcelrTI3rMXdXF5D rvs.example.com. +} + +func TestHIP(t *testing.T) { + h := `www.example.com. 
IN HIP ( 2 200100107B1A74DF365639CC39F1D578 + AwEAAbdxyhNuSutc5EMzxTs9LBPCIkOFH8cIvM4p +9+LrV4e19WzK00+CI6zBCQTdtWsuxKbWIy87UOoJTwkUs7lBu+Upr1gsNrut79ryra+bSRGQ +b1slImA8YVJyuIDsj7kwzG7jnERNqnWxZ48AWkskmdHaVDP4BcelrTI3rMXdXF5D + rvs1.example.com. + rvs2.example.com. )` + rr, err := NewRR(h) + if err != nil { + t.Fatalf("failed to parse RR: %s", err) + } + t.Logf("RR: %s", rr) + msg := new(Msg) + msg.Answer = []RR{rr, rr} + bytes, err := msg.Pack() + if err != nil { + t.Fatalf("failed to pack msg: %s", err) + } + if err := msg.Unpack(bytes); err != nil { + t.Fatalf("failed to unpack msg: %s", err) + } + if len(msg.Answer) != 2 { + t.Fatalf("2 answers expected: %V", msg) + } + for i, rr := range msg.Answer { + rr := rr.(*HIP) + t.Logf("RR: %s", rr) + if l := len(rr.RendezvousServers); l != 2 { + t.Fatalf("2 servers expected, only %d in record %d:\n%V", l, i, msg) + } + for j, s := range []string{"rvs1.example.com.", "rvs2.example.com."} { + if rr.RendezvousServers[j] != s { + t.Fatalf("expected server %d of record %d to be %s:\n%V", j, i, s, msg) + } + } + } +} + +func ExampleSOA() { + s := "example.com. 1000 SOA master.example.com. admin.example.com. 1 4294967294 4294967293 4294967295 100" + if soa, err := NewRR(s); err == nil { + fmt.Printf("%s\n", soa.String()) + } + // Output: + // example.com. 1000 IN SOA master.example.com. admin.example.com. 1 4294967294 4294967293 4294967295 100 +} + +func TestLineNumberError(t *testing.T) { + s := "example.com. 1000 SOA master.example.com. admin.example.com. monkey 4294967294 4294967293 4294967295 100" + if _, err := NewRR(s); err != nil { + if err.Error() != "dns: bad SOA zone parameter: \"monkey\" at line: 1:68" { + t.Logf("not expecting this error: " + err.Error()) + t.Fail() + } + } +} + +// Test with no known RR on the line +func TestLineNumberError2(t *testing.T) { + tests := map[string]string{ + "example.com. 1000 SO master.example.com. admin.example.com. 1 4294967294 4294967293 4294967295 100": "dns: expecting RR type or class, not this...: \"SO\" at line: 1:21", + "example.com 1000 IN TALINK a.example.com. b..example.com.": "dns: bad TALINK NextName: \"b..example.com.\" at line: 1:57", + "example.com 1000 IN TALINK ( a.example.com. b..example.com. )": "dns: bad TALINK NextName: \"b..example.com.\" at line: 1:60", + `example.com 1000 IN TALINK ( a.example.com. + bb..example.com. )`: "dns: bad TALINK NextName: \"bb..example.com.\" at line: 2:18", + // This is a bug, it should report an error on line 1, but the new is already processed. + `example.com 1000 IN TALINK ( a.example.com. b...example.com. + )`: "dns: bad TALINK NextName: \"b...example.com.\" at line: 2:1"} + + for in, err := range tests { + _, e := NewRR(in) + if e == nil { + t.Fail() + } else { + if e.Error() != err { + t.Logf("%s\n", in) + t.Logf("error should be %s is %s\n", err, e.Error()) + t.Fail() + } + } + } +} + +// Test if the calculations are correct +func TestRfc1982(t *testing.T) { + // If the current time and the timestamp are more than 68 years apart + // it means the date has wrapped. 
0 is 1970 + + // fall in the current 68 year span + strtests := []string{"20120525134203", "19700101000000", "20380119031408"} + for _, v := range strtests { + if x, _ := StringToTime(v); v != TimeToString(x) { + t.Logf("1982 arithmetic string failure %s (%s:%d)", v, TimeToString(x), x) + t.Fail() + } + } + + inttests := map[uint32]string{0: "19700101000000", + 1 << 31: "20380119031408", + 1<<32 - 1: "21060207062815", + } + for i, v := range inttests { + if TimeToString(i) != v { + t.Logf("1982 arithmetic int failure %d:%s (%s)", i, v, TimeToString(i)) + t.Fail() + } + } + + // Future tests, these dates get parsed to a date within the current 136 year span + future := map[string]string{"22680119031408": "20631123173144", + "19010101121212": "20370206184028", + "19210101121212": "20570206184028", + "19500101121212": "20860206184028", + "19700101000000": "19700101000000", + "19690101000000": "21050207062816", + "29210101121212": "21040522212236", + } + for from, to := range future { + x, _ := StringToTime(from) + y := TimeToString(x) + if y != to { + t.Logf("1982 arithmetic future failure %s:%s (%s)", from, to, y) + t.Fail() + } + } +} + +func TestEmpty(t *testing.T) { + for _ = range ParseZone(strings.NewReader(""), "", "") { + t.Logf("should be empty") + t.Fail() + } +} + +func TestLowercaseTokens(t *testing.T) { + var testrecords = []string{ + "example.org. 300 IN a 1.2.3.4", + "example.org. 300 in A 1.2.3.4", + "example.org. 300 in a 1.2.3.4", + "example.org. 300 a 1.2.3.4", + "example.org. 300 A 1.2.3.4", + "example.org. IN a 1.2.3.4", + "example.org. in A 1.2.3.4", + "example.org. in a 1.2.3.4", + "example.org. a 1.2.3.4", + "example.org. A 1.2.3.4", + "example.org. a 1.2.3.4", + "$ORIGIN example.org.\n a 1.2.3.4", + "$Origin example.org.\n a 1.2.3.4", + "$origin example.org.\n a 1.2.3.4", + "example.org. Class1 Type1 1.2.3.4", + } + for _, testrr := range testrecords { + _, err := NewRR(testrr) + if err != nil { + t.Errorf("failed to parse %#v, got %s", testrr, err.Error()) + } + } +} + +func ExampleGenerate() { + // From the manual: http://www.bind9.net/manual/bind/9.3.2/Bv9ARM.ch06.html#id2566761 + zone := "$GENERATE 1-2 0 NS SERVER$.EXAMPLE.\n$GENERATE 1-8 $ CNAME $.0" + to := ParseZone(strings.NewReader(zone), "0.0.192.IN-ADDR.ARPA.", "") + for x := range to { + if x.Error == nil { + fmt.Printf("%s\n", x.RR.String()) + } + } + // Output: + // 0.0.0.192.IN-ADDR.ARPA. 3600 IN NS SERVER1.EXAMPLE. + // 0.0.0.192.IN-ADDR.ARPA. 3600 IN NS SERVER2.EXAMPLE. + // 1.0.0.192.IN-ADDR.ARPA. 3600 IN CNAME 1.0.0.0.192.IN-ADDR.ARPA. + // 2.0.0.192.IN-ADDR.ARPA. 3600 IN CNAME 2.0.0.0.192.IN-ADDR.ARPA. + // 3.0.0.192.IN-ADDR.ARPA. 3600 IN CNAME 3.0.0.0.192.IN-ADDR.ARPA. + // 4.0.0.192.IN-ADDR.ARPA. 3600 IN CNAME 4.0.0.0.192.IN-ADDR.ARPA. + // 5.0.0.192.IN-ADDR.ARPA. 3600 IN CNAME 5.0.0.0.192.IN-ADDR.ARPA. + // 6.0.0.192.IN-ADDR.ARPA. 3600 IN CNAME 6.0.0.0.192.IN-ADDR.ARPA. + // 7.0.0.192.IN-ADDR.ARPA. 3600 IN CNAME 7.0.0.0.192.IN-ADDR.ARPA. + // 8.0.0.192.IN-ADDR.ARPA. 3600 IN CNAME 8.0.0.0.192.IN-ADDR.ARPA. 
+} + +func TestSRVPacking(t *testing.T) { + msg := Msg{} + + things := []string{"1.2.3.4:8484", + "45.45.45.45:8484", + "84.84.84.84:8484", + } + + for i, n := range things { + h, p, err := net.SplitHostPort(n) + if err != nil { + continue + } + port := 8484 + tmp, err := strconv.Atoi(p) + if err == nil { + port = tmp + } + + rr := &SRV{ + Hdr: RR_Header{Name: "somename.", + Rrtype: TypeSRV, + Class: ClassINET, + Ttl: 5}, + Priority: uint16(i), + Weight: 5, + Port: uint16(port), + Target: h + ".", + } + + msg.Answer = append(msg.Answer, rr) + } + + _, err := msg.Pack() + if err != nil { + t.Fatalf("couldn't pack %v\n", msg) + } +} + +func TestParseBackslash(t *testing.T) { + if r, e := NewRR("nul\\000gap.test.globnix.net. 600 IN A 192.0.2.10"); e != nil { + t.Fatalf("could not create RR with \\000 in it") + } else { + t.Logf("parsed %s\n", r.String()) + } + if r, e := NewRR(`nul\000gap.test.globnix.net. 600 IN TXT "Hello\123"`); e != nil { + t.Fatalf("could not create RR with \\000 in it") + } else { + t.Logf("parsed %s\n", r.String()) + } + if r, e := NewRR(`m\ @\ iek.nl. IN 3600 A 127.0.0.1`); e != nil { + t.Fatalf("could not create RR with \\ and \\@ in it") + } else { + t.Logf("parsed %s\n", r.String()) + } +} + +func TestILNP(t *testing.T) { + tests := []string{ + "host1.example.com.\t3600\tIN\tNID\t10 0014:4fff:ff20:ee64", + "host1.example.com.\t3600\tIN\tNID\t20 0015:5fff:ff21:ee65", + "host2.example.com.\t3600\tIN\tNID\t10 0016:6fff:ff22:ee66", + "host1.example.com.\t3600\tIN\tL32\t10 10.1.2.0", + "host1.example.com.\t3600\tIN\tL32\t20 10.1.4.0", + "host2.example.com.\t3600\tIN\tL32\t10 10.1.8.0", + "host1.example.com.\t3600\tIN\tL64\t10 2001:0DB8:1140:1000", + "host1.example.com.\t3600\tIN\tL64\t20 2001:0DB8:2140:2000", + "host2.example.com.\t3600\tIN\tL64\t10 2001:0DB8:4140:4000", + "host1.example.com.\t3600\tIN\tLP\t10 l64-subnet1.example.com.", + "host1.example.com.\t3600\tIN\tLP\t10 l64-subnet2.example.com.", + "host1.example.com.\t3600\tIN\tLP\t20 l32-subnet1.example.com.", + } + for _, t1 := range tests { + r, e := NewRR(t1) + if e != nil { + t.Fatalf("an error occured: %s\n", e.Error()) + } else { + if t1 != r.String() { + t.Fatalf("strings should be equal %s %s", t1, r.String()) + } + } + } +} + +func TestNsapGposEidNimloc(t *testing.T) { + dt := map[string]string{ + "foo.bar.com. IN NSAP 21 47000580ffff000000321099991111222233334444": "foo.bar.com.\t3600\tIN\tNSAP\t21 47000580ffff000000321099991111222233334444", + "host.school.de IN NSAP 17 39276f3100111100002222333344449876": "host.school.de.\t3600\tIN\tNSAP\t17 39276f3100111100002222333344449876", + "444433332222111199990123000000ff. NSAP-PTR foo.bar.com.": "444433332222111199990123000000ff.\t3600\tIN\tNSAP-PTR\tfoo.bar.com.", + "lillee. IN GPOS -32.6882 116.8652 10.0": "lillee.\t3600\tIN\tGPOS\t-32.6882 116.8652 10.0", + "hinault. IN GPOS -22.6882 116.8652 250.0": "hinault.\t3600\tIN\tGPOS\t-22.6882 116.8652 250.0", + "VENERA. IN NIMLOC 75234159EAC457800920": "VENERA.\t3600\tIN\tNIMLOC\t75234159EAC457800920", + "VAXA. IN EID 3141592653589793": "VAXA.\t3600\tIN\tEID\t3141592653589793", + } + for i, o := range dt { + rr, e := NewRR(i) + if e != nil { + t.Log("failed to parse RR: " + e.Error()) + t.Fail() + continue + } + if rr.String() != o { + t.Logf("`%s' should be equal to\n`%s', but is `%s'\n", i, o, rr.String()) + t.Fail() + } else { + t.Logf("RR is OK: `%s'", rr.String()) + } + } +} + +func TestPX(t *testing.T) { + dt := map[string]string{ + "*.net2.it. IN PX 10 net2.it. 
PRMD-net2.ADMD-p400.C-it.": "*.net2.it.\t3600\tIN\tPX\t10 net2.it. PRMD-net2.ADMD-p400.C-it.", + "ab.net2.it. IN PX 10 ab.net2.it. O-ab.PRMD-net2.ADMDb.C-it.": "ab.net2.it.\t3600\tIN\tPX\t10 ab.net2.it. O-ab.PRMD-net2.ADMDb.C-it.", + } + for i, o := range dt { + rr, e := NewRR(i) + if e != nil { + t.Log("failed to parse RR: " + e.Error()) + t.Fail() + continue + } + if rr.String() != o { + t.Logf("`%s' should be equal to\n`%s', but is `%s'\n", i, o, rr.String()) + t.Fail() + } else { + t.Logf("RR is OK: `%s'", rr.String()) + } + } +} + +func TestComment(t *testing.T) { + // Comments we must see + comments := map[string]bool{"; this is comment 1": true, + "; this is comment 4": true, "; this is comment 6": true, + "; this is comment 7": true, "; this is comment 8": true} + zone := ` +foo. IN A 10.0.0.1 ; this is comment 1 +foo. IN A ( + 10.0.0.2 ; this is comment2 +) +; this is comment3 +foo. IN A 10.0.0.3 +foo. IN A ( 10.0.0.4 ); this is comment 4 + +foo. IN A 10.0.0.5 +; this is comment5 + +foo. IN A 10.0.0.6 + +foo. IN DNSKEY 256 3 5 AwEAAb+8l ; this is comment 6 +foo. IN NSEC miek.nl. TXT RRSIG NSEC; this is comment 7 +foo. IN TXT "THIS IS TEXT MAN"; this is comment 8 +` + for x := range ParseZone(strings.NewReader(zone), ".", "") { + if x.Error == nil { + if x.Comment != "" { + if _, ok := comments[x.Comment]; !ok { + t.Logf("wrong comment %s", x.Comment) + t.Fail() + } + } + } + } +} + +func TestEUIxx(t *testing.T) { + tests := map[string]string{ + "host.example. IN EUI48 00-00-5e-90-01-2a": "host.example.\t3600\tIN\tEUI48\t00-00-5e-90-01-2a", + "host.example. IN EUI64 00-00-5e-ef-00-00-00-2a": "host.example.\t3600\tIN\tEUI64\t00-00-5e-ef-00-00-00-2a", + } + for i, o := range tests { + r, e := NewRR(i) + if e != nil { + t.Logf("failed to parse %s: %s\n", i, e.Error()) + t.Fail() + } + if r.String() != o { + t.Logf("want %s, got %s\n", o, r.String()) + t.Fail() + } + } +} + +func TestUserRR(t *testing.T) { + tests := map[string]string{ + "host.example. IN UID 1234": "host.example.\t3600\tIN\tUID\t1234", + "host.example. IN GID 1234556": "host.example.\t3600\tIN\tGID\t1234556", + "host.example. IN UINFO \"Miek Gieben\"": "host.example.\t3600\tIN\tUINFO\t\"Miek Gieben\"", + } + for i, o := range tests { + r, e := NewRR(i) + if e != nil { + t.Logf("failed to parse %s: %s\n", i, e.Error()) + t.Fail() + } + if r.String() != o { + t.Logf("want %s, got %s\n", o, r.String()) + t.Fail() + } + } +} + +func TestTXT(t *testing.T) { + // Test single entry TXT record + rr, err := NewRR(`_raop._tcp.local. 60 IN TXT "single value"`) + if err != nil { + t.Error("failed to parse single value TXT record", err) + } else if rr, ok := rr.(*TXT); !ok { + t.Error("wrong type, record should be of type TXT") + } else { + if len(rr.Txt) != 1 { + t.Error("bad size of TXT value:", len(rr.Txt)) + } else if rr.Txt[0] != "single value" { + t.Error("bad single value") + } + if rr.String() != `_raop._tcp.local. 60 IN TXT "single value"` { + t.Error("bad representation of TXT record:", rr.String()) + } + if rr.len() != 28+1+12 { + t.Error("bad size of serialized record:", rr.len()) + } + } + + // Test multi entries TXT record + rr, err = NewRR(`_raop._tcp.local. 
60 IN TXT "a=1" "b=2" "c=3" "d=4"`) + if err != nil { + t.Error("failed to parse multi-values TXT record", err) + } else if rr, ok := rr.(*TXT); !ok { + t.Error("wrong type, record should be of type TXT") + } else { + if len(rr.Txt) != 4 { + t.Error("bad size of TXT multi-value:", len(rr.Txt)) + } else if rr.Txt[0] != "a=1" || rr.Txt[1] != "b=2" || rr.Txt[2] != "c=3" || rr.Txt[3] != "d=4" { + t.Error("bad values in TXT records") + } + if rr.String() != `_raop._tcp.local. 60 IN TXT "a=1" "b=2" "c=3" "d=4"` { + t.Error("bad representation of TXT multi value record:", rr.String()) + } + if rr.len() != 28+1+3+1+3+1+3+1+3 { + t.Error("bad size of serialized multi value record:", rr.len()) + } + } + + // Test empty-string in TXT record + rr, err = NewRR(`_raop._tcp.local. 60 IN TXT ""`) + if err != nil { + t.Error("failed to parse empty-string TXT record", err) + } else if rr, ok := rr.(*TXT); !ok { + t.Error("wrong type, record should be of type TXT") + } else { + if len(rr.Txt) != 1 { + t.Error("bad size of TXT empty-string value:", len(rr.Txt)) + } else if rr.Txt[0] != "" { + t.Error("bad value for empty-string TXT record") + } + if rr.String() != `_raop._tcp.local. 60 IN TXT ""` { + t.Error("bad representation of empty-string TXT record:", rr.String()) + } + if rr.len() != 28+1 { + t.Error("bad size of serialized record:", rr.len()) + } + } +} + +func TestTypeXXXX(t *testing.T) { + _, err := NewRR("example.com IN TYPE1234 \\# 4 aabbccdd") + if err != nil { + t.Logf("failed to parse TYPE1234 RR: %s", err.Error()) + t.Fail() + } + _, err = NewRR("example.com IN TYPE655341 \\# 8 aabbccddaabbccdd") + if err == nil { + t.Logf("this should not work, for TYPE655341") + t.Fail() + } + _, err = NewRR("example.com IN TYPE1 \\# 4 0a000001") + if err == nil { + t.Logf("this should not work") + t.Fail() + } +} + +func TestPTR(t *testing.T) { + _, err := NewRR("144.2.0.192.in-addr.arpa. 900 IN PTR ilouse03146p0\\(.example.com.") + if err != nil { + t.Error("failed to parse ", err.Error()) + } +} + +func TestDigit(t *testing.T) { + tests := map[string]byte{ + "miek\\000.nl. 100 IN TXT \"A\"": 0, + "miek\\001.nl. 100 IN TXT \"A\"": 1, + "miek\\254.nl. 100 IN TXT \"A\"": 254, + "miek\\255.nl. 100 IN TXT \"A\"": 255, + "miek\\256.nl. 100 IN TXT \"A\"": 0, + "miek\\257.nl. 100 IN TXT \"A\"": 1, + "miek\\004.nl. 100 IN TXT \"A\"": 4, + } + for s, i := range tests { + r, e := NewRR(s) + buf := make([]byte, 40) + if e != nil { + t.Fatalf("failed to parse %s\n", e.Error()) + } + PackRR(r, buf, 0, nil, false) + t.Logf("%v\n", buf) + if buf[5] != i { + t.Fatalf("5 pos must be %d, is %d", i, buf[5]) + } + r1, _, _ := UnpackRR(buf, 0) + if r1.Header().Ttl != 100 { + t.Fatalf("TTL should %d, is %d", 100, r1.Header().Ttl) + } + } +} + +func TestParseRRSIGTimestamp(t *testing.T) { + tests := map[string]bool{ + `miek.nl. IN RRSIG SOA 8 2 43200 20140210031301 20140111031301 12051 miek.nl. MVZUyrYwq0iZhMFDDnVXD2BvuNiUJjSYlJAgzyAE6CF875BMvvZa+Sb0 RlSCL7WODQSQHhCx/fegHhVVF+Iz8N8kOLrmXD1+jO3Bm6Prl5UhcsPx WTBsg/kmxbp8sR1kvH4oZJtVfakG3iDerrxNaf0sQwhZzyfJQAqpC7pcBoc=`: true, + `miek.nl. IN RRSIG SOA 8 2 43200 315565800 4102477800 12051 miek.nl. 
MVZUyrYwq0iZhMFDDnVXD2BvuNiUJjSYlJAgzyAE6CF875BMvvZa+Sb0 RlSCL7WODQSQHhCx/fegHhVVF+Iz8N8kOLrmXD1+jO3Bm6Prl5UhcsPx WTBsg/kmxbp8sR1kvH4oZJtVfakG3iDerrxNaf0sQwhZzyfJQAqpC7pcBoc=`: true, + } + for r, _ := range tests { + _, e := NewRR(r) + if e != nil { + t.Fail() + t.Logf("%s\n", e.Error()) + } + } +} + +func TestTxtEqual(t *testing.T) { + rr1 := new(TXT) + rr1.Hdr = RR_Header{Name: ".", Rrtype: TypeTXT, Class: ClassINET, Ttl: 0} + rr1.Txt = []string{"a\"a", "\"", "b"} + rr2, _ := NewRR(rr1.String()) + if rr1.String() != rr2.String() { + t.Logf("these two TXT records should match") + t.Logf("\n%s\n%s\n", rr1.String(), rr2.String()) + t.Fail() // This is not an error, but keep this test. + } + t.Logf("\n%s\n%s\n", rr1.String(), rr2.String()) +} + +func TestTxtLong(t *testing.T) { + rr1 := new(TXT) + rr1.Hdr = RR_Header{Name: ".", Rrtype: TypeTXT, Class: ClassINET, Ttl: 0} + // Make a long txt record, this breaks when sending the packet, + // but not earlier. + rr1.Txt = []string{"start-"} + for i := 0; i < 200; i++ { + rr1.Txt[0] += "start-" + } + str := rr1.String() + if len(str) < len(rr1.Txt[0]) { + t.Logf("string conversion should work") + t.Fail() + } +} + +// Basically, don't crash. +func TestMalformedPackets(t *testing.T) { + var packets = []string{ + "0021641c0000000100000000000078787878787878787878787303636f6d0000100001", + } + + // com = 63 6f 6d + for _, packet := range packets { + data, _ := hex.DecodeString(packet) + // for _, v := range data { + // t.Logf("%s ", string(v)) + // } + var msg Msg + msg.Unpack(data) + // println(msg.String()) + } +} + +func TestDynamicUpdateParsing(t *testing.T) { + prefix := "example.com. IN " + for _, typ := range TypeToString { + if typ == "CAA" || typ == "OPT" || typ == "AXFR" || typ == "IXFR" || typ == "ANY" || typ == "TKEY" || + typ == "TSIG" || typ == "ISDN" || typ == "UNSPEC" || typ == "NULL" || typ == "ATMA" { + continue + } + r, e := NewRR(prefix + typ) + if e != nil { + t.Log("failure to parse: " + prefix + typ) + t.Fail() + } else { + t.Logf("parsed: %s", r.String()) + } + } +} + +type algorithm struct { + name uint8 + bits int +} + +func TestNewPrivateKeyECDSA(t *testing.T) { + algorithms := []algorithm{ + algorithm{ECDSAP256SHA256, 256}, + algorithm{ECDSAP384SHA384, 384}, + algorithm{RSASHA1, 1024}, + algorithm{RSASHA256, 2048}, + // algorithm{DSA, 1024}, // TODO: STILL BROKEN! + } + + for _, algo := range algorithms { + key := new(DNSKEY) + key.Hdr.Rrtype = TypeDNSKEY + key.Hdr.Name = "miek.nl." + key.Hdr.Class = ClassINET + key.Hdr.Ttl = 14400 + key.Flags = 256 + key.Protocol = 3 + key.Algorithm = algo.name + privkey, err := key.Generate(algo.bits) + if err != nil { + t.Fatal(err.Error()) + } + + newPrivKey, err := key.NewPrivateKey(key.PrivateKeyString(privkey)) + if err != nil { + t.Fatal(err.Error()) + } + + switch newPrivKey := newPrivKey.(type) { + case *rsa.PrivateKey: + newPrivKey.Precompute() + } + + if !reflect.DeepEqual(privkey, newPrivKey) { + t.Errorf("[%v] Private keys differ:\n%#v\n%#v\n", AlgorithmToString[algo.name], privkey, newPrivKey) + } + } +} diff --git a/Godeps/_workspace/src/github.com/miekg/dns/privaterr.go b/Godeps/_workspace/src/github.com/miekg/dns/privaterr.go new file mode 100644 index 0000000000..225713743c --- /dev/null +++ b/Godeps/_workspace/src/github.com/miekg/dns/privaterr.go @@ -0,0 +1,122 @@ +/* +PRIVATE RR + +RFC 6895 sets aside a range of type codes for private use. This range +is 65,280 - 65,534 (0xFF00 - 0xFFFE). 
When experimenting with new Resource Records these
+can be used before requesting an official type code from IANA.
+*/
+package dns
+
+import (
+	"fmt"
+	"strings"
+)
+
+// PrivateRdata is an interface used for implementing "Private Use" RR types, see
+// RFC 6895. This allows one to experiment with new RR types without requesting an
+// official type code. Also see dns.PrivateHandle and dns.PrivateHandleRemove.
+type PrivateRdata interface {
+	// String returns the text presentation of the Rdata of the Private RR.
+	String() string
+	// Parse parses the Rdata of the private RR.
+	Parse([]string) error
+	// Pack is used when packing a private RR into a buffer.
+	Pack([]byte) (int, error)
+	// Unpack is used when unpacking a private RR from a buffer.
+	// TODO(miek): diff. signature than Pack, see edns0.go for instance.
+	Unpack([]byte) (int, error)
+	// Copy copies the Rdata.
+	Copy(PrivateRdata) error
+	// Len returns the length in octets of the Rdata.
+	Len() int
+}
+
+// PrivateRR represents an RR that uses a PrivateRdata user-defined type.
+// It mocks normal RRs and implements dns.RR interface.
+type PrivateRR struct {
+	Hdr  RR_Header
+	Data PrivateRdata
+}
+
+func mkPrivateRR(rrtype uint16) *PrivateRR {
+	// Panics if RR is not an instance of PrivateRR.
+	rrfunc, ok := typeToRR[rrtype]
+	if !ok {
+		panic(fmt.Sprintf("dns: invalid operation with Private RR type %d", rrtype))
+	}
+
+	anyrr := rrfunc()
+	switch rr := anyrr.(type) {
+	case *PrivateRR:
+		return rr
+	}
+	panic(fmt.Sprintf("dns: RR is not a PrivateRR, typeToRR[%d] generator returned %T", rrtype, anyrr))
+}
+
+func (r *PrivateRR) Header() *RR_Header { return &r.Hdr }
+func (r *PrivateRR) String() string     { return r.Hdr.String() + r.Data.String() }
+
+// Private len and copy parts to satisfy RR interface.
+func (r *PrivateRR) len() int { return r.Hdr.len() + r.Data.Len() }
+func (r *PrivateRR) copy() RR {
+	// make new RR like this:
+	rr := mkPrivateRR(r.Hdr.Rrtype)
+	newh := r.Hdr.copyHeader()
+	rr.Hdr = *newh
+
+	err := r.Data.Copy(rr.Data)
+	if err != nil {
+		panic("dns: got value that could not be used to copy Private rdata")
+	}
+	return rr
+}
+
+// PrivateHandle registers a private resource record type. It requires the
+// string and numeric representation of the private RR type and a generator
+// function as arguments.
+func PrivateHandle(rtypestr string, rtype uint16, generator func() PrivateRdata) {
+	rtypestr = strings.ToUpper(rtypestr)
+
+	typeToRR[rtype] = func() RR { return &PrivateRR{RR_Header{}, generator()} }
+	TypeToString[rtype] = rtypestr
+	StringToType[rtypestr] = rtype
+
+	setPrivateRR := func(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
+		rr := mkPrivateRR(h.Rrtype)
+		rr.Hdr = h
+
+		var l lex
+		text := make([]string, 0, 2) // could be 0..N elements, median is probably 1
+	FETCH:
+		for {
+			// TODO(miek): we could also be returning _QUOTE, this might or might not
+			// be an issue (basically parsing TXT becomes hard)
+			switch l = <-c; l.value {
+			case _NEWLINE, _EOF:
+				break FETCH
+			case _STRING:
+				text = append(text, l.token)
+			}
+		}
+
+		err := rr.Data.Parse(text)
+		if err != nil {
+			return nil, &ParseError{f, err.Error(), l}, ""
+		}
+
+		return rr, nil, ""
+	}
+
+	typeToparserFunc[rtype] = parserFunc{setPrivateRR, true}
+}
+
+// PrivateHandleRemove removes the definitions required to support a private RR type.
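+//
+// A minimal registration sketch (mirroring the ISBN example in
+// privaterr_test.go; the type name, code and generator are illustrative):
+//
+//	dns.PrivateHandle("ISBN", 0x0F01, NewISBN)
+//	defer dns.PrivateHandleRemove(0x0F01)
+//	rr, err := dns.NewRR("example.org. 3600 IN ISBN 12-3 456789-0-123")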
+func PrivateHandleRemove(rtype uint16) { + rtypestr, ok := TypeToString[rtype] + if ok { + delete(typeToRR, rtype) + delete(TypeToString, rtype) + delete(typeToparserFunc, rtype) + delete(StringToType, rtypestr) + } + return +} diff --git a/Godeps/_workspace/src/github.com/miekg/dns/privaterr_test.go b/Godeps/_workspace/src/github.com/miekg/dns/privaterr_test.go new file mode 100644 index 0000000000..6f1dff89c6 --- /dev/null +++ b/Godeps/_workspace/src/github.com/miekg/dns/privaterr_test.go @@ -0,0 +1,169 @@ +package dns_test + +import ( + "github.com/miekg/dns" + "strings" + "testing" +) + +const TypeISBN uint16 = 0x0F01 + +// A crazy new RR type :) +type ISBN struct { + x string // rdata with 10 or 13 numbers, dashes or spaces allowed +} + +func NewISBN() dns.PrivateRdata { return &ISBN{""} } + +func (rd *ISBN) Len() int { return len([]byte(rd.x)) } +func (rd *ISBN) String() string { return rd.x } + +func (rd *ISBN) Parse(txt []string) error { + rd.x = strings.TrimSpace(strings.Join(txt, " ")) + return nil +} + +func (rd *ISBN) Pack(buf []byte) (int, error) { + b := []byte(rd.x) + n := copy(buf, b) + if n != len(b) { + return n, dns.ErrBuf + } + return n, nil +} + +func (rd *ISBN) Unpack(buf []byte) (int, error) { + rd.x = string(buf) + return len(buf), nil +} + +func (rd *ISBN) Copy(dest dns.PrivateRdata) error { + isbn, ok := dest.(*ISBN) + if !ok { + return dns.ErrRdata + } + isbn.x = rd.x + return nil +} + +var testrecord = strings.Join([]string{"example.org.", "3600", "IN", "ISBN", "12-3 456789-0-123"}, "\t") + +func TestPrivateText(t *testing.T) { + dns.PrivateHandle("ISBN", TypeISBN, NewISBN) + defer dns.PrivateHandleRemove(TypeISBN) + + rr, err := dns.NewRR(testrecord) + if err != nil { + t.Fatal(err) + } + if rr.String() != testrecord { + t.Errorf("record string representation did not match original %#v != %#v", rr.String(), testrecord) + } else { + t.Log(rr.String()) + } +} + +func TestPrivateByteSlice(t *testing.T) { + dns.PrivateHandle("ISBN", TypeISBN, NewISBN) + defer dns.PrivateHandleRemove(TypeISBN) + + rr, err := dns.NewRR(testrecord) + if err != nil { + t.Fatal(err) + } + + buf := make([]byte, 100) + off, err := dns.PackRR(rr, buf, 0, nil, false) + if err != nil { + t.Errorf("got error packing ISBN: %s", err) + } + + custrr := rr.(*dns.PrivateRR) + if ln := custrr.Data.Len() + len(custrr.Header().Name) + 11; ln != off { + t.Errorf("offset is not matching to length of Private RR: %d!=%d", off, ln) + } + + rr1, off1, err := dns.UnpackRR(buf[:off], 0) + if err != nil { + t.Errorf("got error unpacking ISBN: %s", err) + } + + if off1 != off { + t.Errorf("Offset after unpacking differs: %d != %d", off1, off) + } + + if rr1.String() != testrecord { + t.Errorf("Record string representation did not match original %#v != %#v", rr1.String(), testrecord) + } else { + t.Log(rr1.String()) + } +} + +const TypeVERSION uint16 = 0x0F02 + +type VERSION struct { + x string +} + +func NewVersion() dns.PrivateRdata { return &VERSION{""} } + +func (rd *VERSION) String() string { return rd.x } +func (rd *VERSION) Parse(txt []string) error { + rd.x = strings.TrimSpace(strings.Join(txt, " ")) + return nil +} + +func (rd *VERSION) Pack(buf []byte) (int, error) { + b := []byte(rd.x) + n := copy(buf, b) + if n != len(b) { + return n, dns.ErrBuf + } + return n, nil +} + +func (rd *VERSION) Unpack(buf []byte) (int, error) { + rd.x = string(buf) + return len(buf), nil +} + +func (rd *VERSION) Copy(dest dns.PrivateRdata) error { + isbn, ok := dest.(*VERSION) + if !ok { + return dns.ErrRdata + } + 
isbn.x = rd.x
+	return nil
+}
+
+func (rd *VERSION) Len() int {
+	return len([]byte(rd.x))
+}
+
+var smallzone = `$ORIGIN example.org.
+@ SOA sns.dns.icann.org. noc.dns.icann.org. (
+	2014091518 7200 3600 1209600 3600
+)
+	A 1.2.3.4
+ok ISBN 1231-92110-12
+go VERSION (
+	1.3.1 ; comment
+)
+www ISBN 1231-92110-16
+* CNAME @
+`
+
+func TestPrivateZoneParser(t *testing.T) {
+	dns.PrivateHandle("ISBN", TypeISBN, NewISBN)
+	dns.PrivateHandle("VERSION", TypeVERSION, NewVersion)
+	defer dns.PrivateHandleRemove(TypeISBN)
+	defer dns.PrivateHandleRemove(TypeVERSION)
+
+	r := strings.NewReader(smallzone)
+	for x := range dns.ParseZone(r, ".", "") {
+		if err := x.Error; err != nil {
+			t.Fatal(err)
+		}
+		t.Log(x.RR)
+	}
+}
diff --git a/Godeps/_workspace/src/github.com/miekg/dns/rawmsg.go b/Godeps/_workspace/src/github.com/miekg/dns/rawmsg.go
new file mode 100644
index 0000000000..f138b7761d
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/miekg/dns/rawmsg.go
@@ -0,0 +1,95 @@
+package dns
+
+// These raw* functions do not use reflection, they directly set the values
+// in the buffer. They are faster than their reflection counterparts.
+
+// rawSetId sets the message ID in msg.
+func rawSetId(msg []byte, i uint16) bool {
+	if len(msg) < 2 {
+		return false
+	}
+	msg[0], msg[1] = packUint16(i)
+	return true
+}
+
+// rawSetQuestionLen sets the length of the question section.
+func rawSetQuestionLen(msg []byte, i uint16) bool {
+	if len(msg) < 6 {
+		return false
+	}
+	msg[4], msg[5] = packUint16(i)
+	return true
+}
+
+// rawSetAnswerLen sets the length of the answer section.
+func rawSetAnswerLen(msg []byte, i uint16) bool {
+	if len(msg) < 8 {
+		return false
+	}
+	msg[6], msg[7] = packUint16(i)
+	return true
+}
+
+// rawSetNsLen sets the length of the authority section.
+func rawSetNsLen(msg []byte, i uint16) bool {
+	if len(msg) < 10 {
+		return false
+	}
+	msg[8], msg[9] = packUint16(i)
+	return true
+}
+
+// rawSetExtraLen sets the length of the additional section.
+func rawSetExtraLen(msg []byte, i uint16) bool {
+	if len(msg) < 12 {
+		return false
+	}
+	msg[10], msg[11] = packUint16(i)
+	return true
+}
+
+// rawSetRdlength sets the rdlength in the header of
+// the RR. The offset 'off' must be positioned at the
+// start of the header of the RR, 'end' must be the
+// end of the RR.
+func rawSetRdlength(msg []byte, off, end int) bool {
+	l := len(msg)
+Loop:
+	for {
+		if off+1 > l {
+			return false
+		}
+		c := int(msg[off])
+		off++
+		switch c & 0xC0 {
+		case 0x00:
+			if c == 0x00 {
+				// End of the domainname
+				break Loop
+			}
+			if off+c > l {
+				return false
+			}
+			off += c
+
+		case 0xC0:
+			// pointer, next byte included, ends domainname
+			off++
+			break Loop
+		}
+	}
+	// The domainname has been seen, we are at the start of the fixed part in the header.
+	// Type is 2 bytes, class is 2 bytes, ttl 4 and then 2 bytes for the length.
+	off += 2 + 2 + 4
+	if off+2 > l {
+		return false
+	}
+	// off+1 is the end of the header, 'end' is the end of the rr,
+	// so 'end' - (off+2) is the length of the rdata
+	rdatalen := end - (off + 2)
+	if rdatalen > 0xFFFF {
+		return false
+	}
+	msg[off], msg[off+1] = packUint16(uint16(rdatalen))
+	return true
+}
diff --git a/Godeps/_workspace/src/github.com/miekg/dns/scanner.go b/Godeps/_workspace/src/github.com/miekg/dns/scanner.go
new file mode 100644
index 0000000000..c29bc2f388
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/miekg/dns/scanner.go
@@ -0,0 +1,43 @@
+package dns
+
+// Implement a simple scanner, return a byte stream from an io.Reader.
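+// Line and column positions are tracked so the zone parser can report
+// the exact location of a parse error.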
+
+import (
+	"bufio"
+	"io"
+	"text/scanner"
+)
+
+type scan struct {
+	src      *bufio.Reader
+	position scanner.Position
+	eof      bool // Have we just seen an EOF
+}
+
+func scanInit(r io.Reader) *scan {
+	s := new(scan)
+	s.src = bufio.NewReader(r)
+	s.position.Line = 1
+	return s
+}
+
+// tokenText returns the next byte from the input
+func (s *scan) tokenText() (byte, error) {
+	c, err := s.src.ReadByte()
+	if err != nil {
+		return c, err
+	}
+	// delay the newline handling until the next token is delivered,
+	// fixes off-by-one errors when reporting a parse error.
+	if s.eof {
+		s.position.Line++
+		s.position.Column = 0
+		s.eof = false
+	}
+	if c == '\n' {
+		s.eof = true
+		return c, nil
+	}
+	s.position.Column++
+	return c, nil
+}
diff --git a/Godeps/_workspace/src/github.com/miekg/dns/server.go b/Godeps/_workspace/src/github.com/miekg/dns/server.go
new file mode 100644
index 0000000000..cf3353bd76
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/miekg/dns/server.go
@@ -0,0 +1,612 @@
+// DNS server implementation.
+
+package dns
+
+import (
+	"bytes"
+	"io"
+	"net"
+	"sync"
+	"time"
+)
+
+type Handler interface {
+	ServeDNS(w ResponseWriter, r *Msg)
+}
+
+// A ResponseWriter interface is used by a DNS handler to
+// construct a DNS response.
+type ResponseWriter interface {
+	// LocalAddr returns the net.Addr of the server
+	LocalAddr() net.Addr
+	// RemoteAddr returns the net.Addr of the client that sent the current request.
+	RemoteAddr() net.Addr
+	// WriteMsg writes a reply back to the client.
+	WriteMsg(*Msg) error
+	// Write writes a raw buffer back to the client.
+	Write([]byte) (int, error)
+	// Close closes the connection.
+	Close() error
+	// TsigStatus returns the status of the Tsig.
+	TsigStatus() error
+	// TsigTimersOnly sets the tsig timers only boolean.
+	TsigTimersOnly(bool)
+	// Hijack lets the caller take over the connection.
+	// After a call to Hijack(), the DNS package will not do anything with the connection.
+	Hijack()
+}
+
+type response struct {
+	hijacked       bool // connection has been hijacked by handler
+	tsigStatus     error
+	tsigTimersOnly bool
+	tsigRequestMAC string
+	tsigSecret     map[string]string // the tsig secrets
+	udp            *net.UDPConn      // i/o connection if UDP was used
+	tcp            *net.TCPConn      // i/o connection if TCP was used
+	udpSession     *sessionUDP       // oob data to get egress interface right
+	remoteAddr     net.Addr          // address of the client
+}
+
+// ServeMux is a DNS request multiplexer. It matches the
+// zone name of each incoming request against a list of
+// registered patterns and calls the handler for the pattern
+// that most closely matches the zone name. ServeMux is DNSSEC aware, meaning
+// that queries for the DS record are redirected to the parent zone (if that
+// is also registered), otherwise the child gets the query.
+// ServeMux is also safe for concurrent access from multiple goroutines.
+type ServeMux struct {
+	z map[string]Handler
+	m *sync.RWMutex
+}
+
+// NewServeMux allocates and returns a new ServeMux.
+func NewServeMux() *ServeMux { return &ServeMux{z: make(map[string]Handler), m: new(sync.RWMutex)} }
+
+// DefaultServeMux is the default ServeMux used by Serve.
+var DefaultServeMux = NewServeMux()
+
+// The HandlerFunc type is an adapter to allow the use of
+// ordinary functions as DNS handlers. If f is a function
+// with the appropriate signature, HandlerFunc(f) is a
+// Handler object that calls f.
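+//
+// A minimal sketch (the handler body is illustrative):
+//
+//	dns.Handle("example.org.", dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
+//		m := new(dns.Msg)
+//		m.SetReply(r)
+//		w.WriteMsg(m)
+//	}))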
+type HandlerFunc func(ResponseWriter, *Msg)
+
+// ServeDNS calls f(w, r).
+func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) {
+	f(w, r)
+}
+
+// HandleFailed returns a SERVFAIL for every request it gets.
+func HandleFailed(w ResponseWriter, r *Msg) {
+	m := new(Msg)
+	m.SetRcode(r, RcodeServerFailure)
+	// does not matter if this write fails
+	w.WriteMsg(m)
+}
+
+func failedHandler() Handler { return HandlerFunc(HandleFailed) }
+
+// ListenAndServe starts a server on the address and network specified,
+// invoking handler for incoming queries.
+func ListenAndServe(addr string, network string, handler Handler) error {
+	server := &Server{Addr: addr, Net: network, Handler: handler}
+	return server.ListenAndServe()
+}
+
+// ActivateAndServe activates a server with a listener from systemd;
+// l and p should not both be non-nil.
+// If both l and p are non-nil only p will be used.
+// Handler is invoked for incoming queries.
+func ActivateAndServe(l net.Listener, p net.PacketConn, handler Handler) error {
+	server := &Server{Listener: l, PacketConn: p, Handler: handler}
+	return server.ActivateAndServe()
+}
+
+func (mux *ServeMux) match(q string, t uint16) Handler {
+	mux.m.RLock()
+	defer mux.m.RUnlock()
+	var handler Handler
+	b := make([]byte, len(q)) // worst case, one label of length q
+	off := 0
+	end := false
+	for {
+		l := len(q[off:])
+		for i := 0; i < l; i++ {
+			b[i] = q[off+i]
+			if b[i] >= 'A' && b[i] <= 'Z' {
+				b[i] |= ('a' - 'A')
+			}
+		}
+		if h, ok := mux.z[string(b[:l])]; ok { // 'causes garbage, might want to change the map key
+			if t != TypeDS {
+				return h
+			} else {
+				// Continue for DS to see if we have a parent too, if so delegate to the parent
+				handler = h
+			}
+		}
+		off, end = NextLabel(q, off)
+		if end {
+			break
+		}
+	}
+	// Wildcard match, if we have found nothing try the root zone as a last resort.
+	if h, ok := mux.z["."]; ok {
+		return h
+	}
+	return handler
+}
+
+// Handle adds a handler to the ServeMux for pattern.
+func (mux *ServeMux) Handle(pattern string, handler Handler) {
+	if pattern == "" {
+		panic("dns: invalid pattern " + pattern)
+	}
+	mux.m.Lock()
+	mux.z[Fqdn(pattern)] = handler
+	mux.m.Unlock()
+}
+
+// HandleFunc adds a handler function to the ServeMux for pattern.
+func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
+	mux.Handle(pattern, HandlerFunc(handler))
+}
+
+// HandleRemove deregisters the handler for pattern from the ServeMux.
+func (mux *ServeMux) HandleRemove(pattern string) {
+	if pattern == "" {
+		panic("dns: invalid pattern " + pattern)
+	}
+	// don't need a mutex here, because deleting is OK, even if the
+	// entry is not there.
+	delete(mux.z, Fqdn(pattern))
+}
+
+// ServeDNS dispatches the request to the handler whose
+// pattern most closely matches the request message. If DefaultServeMux
+// is used the correct thing for DS queries is done: a possible parent
+// is sought.
+// If no handler is found a standard SERVFAIL message is returned.
+// If the request message does not have exactly one question in the
+// question section a SERVFAIL is returned.
+func (mux *ServeMux) ServeDNS(w ResponseWriter, request *Msg) {
+	var h Handler
+	if len(request.Question) != 1 {
+		h = failedHandler()
+	} else {
+		if h = mux.match(request.Question[0].Name, request.Question[0].Qtype); h == nil {
+			h = failedHandler()
+		}
+	}
+	h.ServeDNS(w, request)
+}
+
+// Handle registers the handler with the given pattern
+// in the DefaultServeMux.
The documentation for +// ServeMux explains how patterns are matched. +func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) } + +// HandleRemove deregisters the handle with the given pattern +// in the DefaultServeMux. +func HandleRemove(pattern string) { DefaultServeMux.HandleRemove(pattern) } + +// HandleFunc registers the handler function with the given pattern +// in the DefaultServeMux. +func HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) { + DefaultServeMux.HandleFunc(pattern, handler) +} + +// A Server defines parameters for running an DNS server. +type Server struct { + // Address to listen on, ":dns" if empty. + Addr string + // if "tcp" it will invoke a TCP listener, otherwise an UDP one. + Net string + // TCP Listener to use, this is to aid in systemd's socket activation. + Listener net.Listener + // UDP "Listener" to use, this is to aid in systemd's socket activation. + PacketConn net.PacketConn + // Handler to invoke, dns.DefaultServeMux if nil. + Handler Handler + // Default buffer size to use to read incoming UDP messages. If not set + // it defaults to MinMsgSize (512 B). + UDPSize int + // The net.Conn.SetReadTimeout value for new connections, defaults to 2 * time.Second. + ReadTimeout time.Duration + // The net.Conn.SetWriteTimeout value for new connections, defaults to 2 * time.Second. + WriteTimeout time.Duration + // TCP idle timeout for multiple queries, if nil, defaults to 8 * time.Second (RFC 5966). + IdleTimeout func() time.Duration + // Secret(s) for Tsig map[]. + TsigSecret map[string]string + + // For graceful shutdown. + stopUDP chan bool + stopTCP chan bool + wgUDP sync.WaitGroup + wgTCP sync.WaitGroup + + // make start/shutdown not racy + lock sync.Mutex + started bool +} + +// ListenAndServe starts a nameserver on the configured address in *Server. +func (srv *Server) ListenAndServe() error { + srv.lock.Lock() + if srv.started { + return &Error{err: "server already started"} + } + srv.stopUDP, srv.stopTCP = make(chan bool), make(chan bool) + srv.started = true + srv.lock.Unlock() + addr := srv.Addr + if addr == "" { + addr = ":domain" + } + if srv.UDPSize == 0 { + srv.UDPSize = MinMsgSize + } + switch srv.Net { + case "tcp", "tcp4", "tcp6": + a, e := net.ResolveTCPAddr(srv.Net, addr) + if e != nil { + return e + } + l, e := net.ListenTCP(srv.Net, a) + if e != nil { + return e + } + return srv.serveTCP(l) + case "udp", "udp4", "udp6": + a, e := net.ResolveUDPAddr(srv.Net, addr) + if e != nil { + return e + } + l, e := net.ListenUDP(srv.Net, a) + if e != nil { + return e + } + if e := setUDPSocketOptions(l); e != nil { + return e + } + return srv.serveUDP(l) + } + return &Error{err: "bad network"} +} + +// ActivateAndServe starts a nameserver with the PacketConn or Listener +// configured in *Server. Its main use is to start a server from systemd. 
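+//
+// A minimal sketch, assuming a listener was handed to the process (the
+// address is illustrative):
+//
+//	l, err := net.Listen("tcp", "127.0.0.1:8053")
+//	if err != nil { ... }
+//	srv := &dns.Server{Listener: l, Handler: dns.DefaultServeMux}
+//	go srv.ActivateAndServe()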
+func (srv *Server) ActivateAndServe() error { + srv.lock.Lock() + if srv.started { + return &Error{err: "server already started"} + } + srv.stopUDP, srv.stopTCP = make(chan bool), make(chan bool) + srv.started = true + srv.lock.Unlock() + if srv.UDPSize == 0 { + srv.UDPSize = MinMsgSize + } + if srv.PacketConn != nil { + if srv.UDPSize == 0 { + srv.UDPSize = MinMsgSize + } + if t, ok := srv.PacketConn.(*net.UDPConn); ok { + if e := setUDPSocketOptions(t); e != nil { + return e + } + return srv.serveUDP(t) + } + } + if srv.Listener != nil { + if t, ok := srv.Listener.(*net.TCPListener); ok { + return srv.serveTCP(t) + } + } + return &Error{err: "bad listeners"} +} + +// Shutdown gracefully shuts down a server. After a call to Shutdown, ListenAndServe and +// ActivateAndServe will return. All in progress queries are completed before the server +// is taken down. If the Shutdown is taking longer than the reading timeout and error +// is returned. +func (srv *Server) Shutdown() error { + srv.lock.Lock() + if !srv.started { + return &Error{err: "server not started"} + } + srv.started = false + srv.lock.Unlock() + net, addr := srv.Net, srv.Addr + switch { + case srv.Listener != nil: + a := srv.Listener.Addr() + net, addr = a.Network(), a.String() + case srv.PacketConn != nil: + a := srv.PacketConn.LocalAddr() + net, addr = a.Network(), a.String() + } + + fin := make(chan bool) + switch net { + case "tcp", "tcp4", "tcp6": + go func() { + srv.stopTCP <- true + srv.wgTCP.Wait() + fin <- true + }() + + case "udp", "udp4", "udp6": + go func() { + srv.stopUDP <- true + srv.wgUDP.Wait() + fin <- true + }() + } + + c := &Client{Net: net} + go c.Exchange(new(Msg), addr) // extra query to help ReadXXX loop to pass + + select { + case <-time.After(srv.getReadTimeout()): + return &Error{err: "server shutdown is pending"} + case <-fin: + return nil + } +} + +// getReadTimeout is a helper func to use system timeout if server did not intend to change it. +func (srv *Server) getReadTimeout() time.Duration { + rtimeout := dnsTimeout + if srv.ReadTimeout != 0 { + rtimeout = srv.ReadTimeout + } + return rtimeout +} + +// serveTCP starts a TCP listener for the server. +// Each request is handled in a seperate goroutine. +func (srv *Server) serveTCP(l *net.TCPListener) error { + defer l.Close() + handler := srv.Handler + if handler == nil { + handler = DefaultServeMux + } + rtimeout := srv.getReadTimeout() + // deadline is not used here + for { + rw, e := l.AcceptTCP() + if e != nil { + continue + } + m, e := srv.readTCP(rw, rtimeout) + select { + case <-srv.stopTCP: + return nil + default: + } + if e != nil { + continue + } + srv.wgTCP.Add(1) + go srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw) + } + panic("dns: not reached") +} + +// serveUDP starts a UDP listener for the server. +// Each request is handled in a seperate goroutine. +func (srv *Server) serveUDP(l *net.UDPConn) error { + defer l.Close() + + handler := srv.Handler + if handler == nil { + handler = DefaultServeMux + } + rtimeout := srv.getReadTimeout() + // deadline is not used here + for { + m, s, e := srv.readUDP(l, rtimeout) + select { + case <-srv.stopUDP: + return nil + default: + } + if e != nil { + continue + } + srv.wgUDP.Add(1) + go srv.serve(s.RemoteAddr(), handler, m, l, s, nil) + } + panic("dns: not reached") +} + +// Serve a new connection. 
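+// For TCP connections the function loops (see the Redo label below): up to
+// 128 queries are answered on one connection, honoring IdleTimeout between
+// reads, before the socket is closed.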
+func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *sessionUDP, t *net.TCPConn) { + w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s} + q := 0 + defer func() { + if u != nil { + srv.wgUDP.Done() + } + if t != nil { + srv.wgTCP.Done() + } + }() +Redo: + req := new(Msg) + err := req.Unpack(m) + if err != nil { // Send a FormatError back + x := new(Msg) + x.SetRcodeFormatError(req) + w.WriteMsg(x) + goto Exit + } + + w.tsigStatus = nil + if w.tsigSecret != nil { + if t := req.IsTsig(); t != nil { + secret := t.Hdr.Name + if _, ok := w.tsigSecret[secret]; !ok { + w.tsigStatus = ErrKeyAlg + } + w.tsigStatus = TsigVerify(m, w.tsigSecret[secret], "", false) + w.tsigTimersOnly = false + w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC + } + } + h.ServeDNS(w, req) // Writes back to the client + +Exit: + if w.hijacked { + return // client calls Close() + } + if u != nil { // UDP, "close" and return + w.Close() + return + } + idleTimeout := tcpIdleTimeout + if srv.IdleTimeout != nil { + idleTimeout = srv.IdleTimeout() + } + m, e := srv.readTCP(w.tcp, idleTimeout) + if e == nil { + q++ + // TODO(miek): make this number configurable? + if q > 128 { // close socket after this many queries + w.Close() + return + } + goto Redo + } + w.Close() + return +} + +func (srv *Server) readTCP(conn *net.TCPConn, timeout time.Duration) ([]byte, error) { + conn.SetReadDeadline(time.Now().Add(timeout)) + l := make([]byte, 2) + n, err := conn.Read(l) + if err != nil || n != 2 { + if err != nil { + return nil, err + } + return nil, ErrShortRead + } + length, _ := unpackUint16(l, 0) + if length == 0 { + return nil, ErrShortRead + } + m := make([]byte, int(length)) + n, err = conn.Read(m[:int(length)]) + if err != nil || n == 0 { + if err != nil { + return nil, err + } + return nil, ErrShortRead + } + i := n + for i < int(length) { + j, err := conn.Read(m[i:int(length)]) + if err != nil { + return nil, err + } + i += j + } + n = i + m = m[:n] + return m, nil +} + +func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *sessionUDP, error) { + conn.SetReadDeadline(time.Now().Add(timeout)) + m := make([]byte, srv.UDPSize) + n, s, e := readFromSessionUDP(conn, m) + if e != nil || n == 0 { + if e != nil { + return nil, nil, e + } + return nil, nil, ErrShortRead + } + m = m[:n] + return m, s, nil +} + +// WriteMsg implements the ResponseWriter.WriteMsg method. +func (w *response) WriteMsg(m *Msg) (err error) { + var data []byte + if w.tsigSecret != nil { // if no secrets, dont check for the tsig (which is a longer check) + if t := m.IsTsig(); t != nil { + data, w.tsigRequestMAC, err = TsigGenerate(m, w.tsigSecret[t.Hdr.Name], w.tsigRequestMAC, w.tsigTimersOnly) + if err != nil { + return err + } + _, err = w.Write(data) + return err + } + } + data, err = m.Pack() + if err != nil { + return err + } + _, err = w.Write(data) + return err +} + +// Write implements the ResponseWriter.Write method. +func (w *response) Write(m []byte) (int, error) { + switch { + case w.udp != nil: + n, err := writeToSessionUDP(w.udp, m, w.udpSession) + return n, err + case w.tcp != nil: + lm := len(m) + if lm < 2 { + return 0, io.ErrShortBuffer + } + if lm > MaxMsgSize { + return 0, &Error{err: "message too large"} + } + l := make([]byte, 2, 2+lm) + l[0], l[1] = packUint16(uint16(lm)) + m = append(l, m...) 
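+		// A DNS message over TCP carries a two-octet length field before the
+		// message itself (RFC 1035, section 4.2.2); prepending it here lets
+		// the whole frame go out in the single write below.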
+ + n, err := io.Copy(w.tcp, bytes.NewReader(m)) + return int(n), err + } + panic("not reached") +} + +// LocalAddr implements the ResponseWriter.LocalAddr method. +func (w *response) LocalAddr() net.Addr { + if w.tcp != nil { + return w.tcp.LocalAddr() + } + return w.udp.LocalAddr() +} + +// RemoteAddr implements the ResponseWriter.RemoteAddr method. +func (w *response) RemoteAddr() net.Addr { return w.remoteAddr } + +// TsigStatus implements the ResponseWriter.TsigStatus method. +func (w *response) TsigStatus() error { return w.tsigStatus } + +// TsigTimersOnly implements the ResponseWriter.TsigTimersOnly method. +func (w *response) TsigTimersOnly(b bool) { w.tsigTimersOnly = b } + +// Hijack implements the ResponseWriter.Hijack method. +func (w *response) Hijack() { w.hijacked = true } + +// Close implements the ResponseWriter.Close method +func (w *response) Close() error { + // Can't close the udp conn, as that is actually the listener. + if w.tcp != nil { + e := w.tcp.Close() + w.tcp = nil + return e + } + return nil +} diff --git a/Godeps/_workspace/src/github.com/miekg/dns/server_test.go b/Godeps/_workspace/src/github.com/miekg/dns/server_test.go new file mode 100644 index 0000000000..b743013897 --- /dev/null +++ b/Godeps/_workspace/src/github.com/miekg/dns/server_test.go @@ -0,0 +1,331 @@ +package dns + +import ( + "fmt" + "net" + "runtime" + "sync" + "testing" +) + +func HelloServer(w ResponseWriter, req *Msg) { + m := new(Msg) + m.SetReply(req) + + m.Extra = make([]RR, 1) + m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello world"}} + w.WriteMsg(m) +} + +func AnotherHelloServer(w ResponseWriter, req *Msg) { + m := new(Msg) + m.SetReply(req) + + m.Extra = make([]RR, 1) + m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello example"}} + w.WriteMsg(m) +} + +func RunLocalUDPServer(laddr string) (*Server, string, error) { + pc, err := net.ListenPacket("udp", laddr) + if err != nil { + return nil, "", err + } + server := &Server{PacketConn: pc} + go func() { + server.ActivateAndServe() + pc.Close() + }() + return server, pc.LocalAddr().String(), nil +} + +func RunLocalTCPServer(laddr string) (*Server, string, error) { + l, err := net.Listen("tcp", laddr) + if err != nil { + return nil, "", err + } + server := &Server{Listener: l} + go func() { + server.ActivateAndServe() + l.Close() + }() + return server, l.Addr().String(), nil +} + +func TestServing(t *testing.T) { + HandleFunc("miek.nl.", HelloServer) + HandleFunc("example.com.", AnotherHelloServer) + defer HandleRemove("miek.nl.") + defer HandleRemove("example.com.") + + s, addrstr, err := RunLocalUDPServer("127.0.0.1:0") + if err != nil { + t.Fatalf("Unable to run test server: %s", err) + } + defer s.Shutdown() + + c := new(Client) + m := new(Msg) + m.SetQuestion("miek.nl.", TypeTXT) + r, _, err := c.Exchange(m, addrstr) + if err != nil { + t.Log("failed to exchange miek.nl", err) + t.Fatal() + } + txt := r.Extra[0].(*TXT).Txt[0] + if txt != "Hello world" { + t.Log("Unexpected result for miek.nl", txt, "!= Hello world") + t.Fail() + } + + m.SetQuestion("example.com.", TypeTXT) + r, _, err = c.Exchange(m, addrstr) + if err != nil { + t.Log("failed to exchange example.com", err) + t.Fatal() + } + txt = r.Extra[0].(*TXT).Txt[0] + if txt != "Hello example" { + t.Log("Unexpected result for example.com", txt, "!= Hello example") + t.Fail() + } + + // Test Mixes cased as noticed by Ask. 
+ m.SetQuestion("eXaMplE.cOm.", TypeTXT) + r, _, err = c.Exchange(m, addrstr) + if err != nil { + t.Log("failed to exchange eXaMplE.cOm", err) + t.Fail() + } + txt = r.Extra[0].(*TXT).Txt[0] + if txt != "Hello example" { + t.Log("Unexpected result for example.com", txt, "!= Hello example") + t.Fail() + } +} + +func BenchmarkServe(b *testing.B) { + b.StopTimer() + HandleFunc("miek.nl.", HelloServer) + defer HandleRemove("miek.nl.") + a := runtime.GOMAXPROCS(4) + + s, addrstr, err := RunLocalUDPServer("127.0.0.1:0") + if err != nil { + b.Fatalf("Unable to run test server: %s", err) + } + defer s.Shutdown() + + c := new(Client) + m := new(Msg) + m.SetQuestion("miek.nl", TypeSOA) + + b.StartTimer() + for i := 0; i < b.N; i++ { + c.Exchange(m, addrstr) + } + runtime.GOMAXPROCS(a) +} + +func benchmarkServe6(b *testing.B) { + b.StopTimer() + HandleFunc("miek.nl.", HelloServer) + defer HandleRemove("miek.nl.") + a := runtime.GOMAXPROCS(4) + s, addrstr, err := RunLocalUDPServer("[::1]:0") + if err != nil { + b.Fatalf("Unable to run test server: %s", err) + } + defer s.Shutdown() + + c := new(Client) + m := new(Msg) + m.SetQuestion("miek.nl", TypeSOA) + + b.StartTimer() + for i := 0; i < b.N; i++ { + c.Exchange(m, addrstr) + } + runtime.GOMAXPROCS(a) +} + +func HelloServerCompress(w ResponseWriter, req *Msg) { + m := new(Msg) + m.SetReply(req) + m.Extra = make([]RR, 1) + m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello world"}} + m.Compress = true + w.WriteMsg(m) +} + +func BenchmarkServeCompress(b *testing.B) { + b.StopTimer() + HandleFunc("miek.nl.", HelloServerCompress) + defer HandleRemove("miek.nl.") + a := runtime.GOMAXPROCS(4) + s, addrstr, err := RunLocalUDPServer("127.0.0.1:0") + if err != nil { + b.Fatalf("Unable to run test server: %s", err) + } + defer s.Shutdown() + + c := new(Client) + m := new(Msg) + m.SetQuestion("miek.nl", TypeSOA) + b.StartTimer() + for i := 0; i < b.N; i++ { + c.Exchange(m, addrstr) + } + runtime.GOMAXPROCS(a) +} + +func TestDotAsCatchAllWildcard(t *testing.T) { + mux := NewServeMux() + mux.Handle(".", HandlerFunc(HelloServer)) + mux.Handle("example.com.", HandlerFunc(AnotherHelloServer)) + + handler := mux.match("www.miek.nl.", TypeTXT) + if handler == nil { + t.Error("wildcard match failed") + } + + handler = mux.match("www.example.com.", TypeTXT) + if handler == nil { + t.Error("example.com match failed") + } + + handler = mux.match("a.www.example.com.", TypeTXT) + if handler == nil { + t.Error("a.www.example.com match failed") + } + + handler = mux.match("boe.", TypeTXT) + if handler == nil { + t.Error("boe. 
match failed") + } +} + +func TestCaseFolding(t *testing.T) { + mux := NewServeMux() + mux.Handle("_udp.example.com.", HandlerFunc(HelloServer)) + + handler := mux.match("_dns._udp.example.com.", TypeSRV) + if handler == nil { + t.Error("case sensitive characters folded") + } + + handler = mux.match("_DNS._UDP.EXAMPLE.COM.", TypeSRV) + if handler == nil { + t.Error("case insensitive characters not folded") + } +} + +func TestRootServer(t *testing.T) { + mux := NewServeMux() + mux.Handle(".", HandlerFunc(HelloServer)) + + handler := mux.match(".", TypeNS) + if handler == nil { + t.Error("root match failed") + } +} + +type maxRec struct { + max int + sync.RWMutex +} + +var M = new(maxRec) + +func HelloServerLargeResponse(resp ResponseWriter, req *Msg) { + m := new(Msg) + m.SetReply(req) + m.Authoritative = true + m1 := 0 + M.RLock() + m1 = M.max + M.RUnlock() + for i := 0; i < m1; i++ { + aRec := &A{ + Hdr: RR_Header{ + Name: req.Question[0].Name, + Rrtype: TypeA, + Class: ClassINET, + Ttl: 0, + }, + A: net.ParseIP(fmt.Sprintf("127.0.0.%d", i+1)).To4(), + } + m.Answer = append(m.Answer, aRec) + } + resp.WriteMsg(m) +} + +func TestServingLargeResponses(t *testing.T) { + HandleFunc("example.", HelloServerLargeResponse) + defer HandleRemove("example.") + + s, addrstr, err := RunLocalUDPServer("127.0.0.1:0") + if err != nil { + t.Fatalf("Unable to run test server: %s", err) + } + defer s.Shutdown() + + // Create request + m := new(Msg) + m.SetQuestion("web.service.example.", TypeANY) + + c := new(Client) + c.Net = "udp" + M.Lock() + M.max = 2 + M.Unlock() + _, _, err = c.Exchange(m, addrstr) + if err != nil { + t.Logf("failed to exchange: %s", err.Error()) + t.Fail() + } + // This must fail + M.Lock() + M.max = 20 + M.Unlock() + _, _, err = c.Exchange(m, addrstr) + if err == nil { + t.Logf("failed to fail exchange, this should generate packet error") + t.Fail() + } + // But this must work again + c.UDPSize = 7000 + _, _, err = c.Exchange(m, addrstr) + if err != nil { + t.Logf("failed to exchange: %s", err.Error()) + t.Fail() + } +} + +func TestShutdownTCP(t *testing.T) { + s, _, err := RunLocalTCPServer("127.0.0.1:0") + if err != nil { + t.Fatalf("Unable to run test server: %s", err) + } + // it normally is too early to shutting down because server + // activates in goroutine. + runtime.Gosched() + err = s.Shutdown() + if err != nil { + t.Errorf("Could not shutdown test TCP server, %s", err) + } +} + +func TestShutdownUDP(t *testing.T) { + s, _, err := RunLocalUDPServer("127.0.0.1:0") + if err != nil { + t.Fatalf("Unable to run test server: %s", err) + } + // it normally is too early to shutting down because server + // activates in goroutine. + runtime.Gosched() + err = s.Shutdown() + if err != nil { + t.Errorf("Could not shutdown test UDP server, %s", err) + } +} diff --git a/Godeps/_workspace/src/github.com/miekg/dns/singleinflight.go b/Godeps/_workspace/src/github.com/miekg/dns/singleinflight.go new file mode 100644 index 0000000000..9573c7d0b8 --- /dev/null +++ b/Godeps/_workspace/src/github.com/miekg/dns/singleinflight.go @@ -0,0 +1,57 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Adapted for dns package usage by Miek Gieben. 
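+//
+// A minimal sketch of the intended use (the key and query function are
+// illustrative):
+//
+//	var g singleflight
+//	r, rtt, err, shared := g.Do("miek.nl.:MX", func() (*Msg, time.Duration, error) {
+//		return doQuery() // concurrent callers with the same key share this result
+//	})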
+ +package dns + +import "sync" +import "time" + +// call is an in-flight or completed singleflight.Do call +type call struct { + wg sync.WaitGroup + val *Msg + rtt time.Duration + err error + dups int +} + +// singleflight represents a class of work and forms a namespace in +// which units of work can be executed with duplicate suppression. +type singleflight struct { + sync.Mutex // protects m + m map[string]*call // lazily initialized +} + +// Do executes and returns the results of the given function, making +// sure that only one execution is in-flight for a given key at a +// time. If a duplicate comes in, the duplicate caller waits for the +// original to complete and receives the same results. +// The return value shared indicates whether v was given to multiple callers. +func (g *singleflight) Do(key string, fn func() (*Msg, time.Duration, error)) (v *Msg, rtt time.Duration, err error, shared bool) { + g.Lock() + if g.m == nil { + g.m = make(map[string]*call) + } + if c, ok := g.m[key]; ok { + c.dups++ + g.Unlock() + c.wg.Wait() + return c.val, c.rtt, c.err, true + } + c := new(call) + c.wg.Add(1) + g.m[key] = c + g.Unlock() + + c.val, c.rtt, c.err = fn() + c.wg.Done() + + g.Lock() + delete(g.m, key) + g.Unlock() + + return c.val, c.rtt, c.err, c.dups > 0 +} diff --git a/Godeps/_workspace/src/github.com/miekg/dns/tlsa.go b/Godeps/_workspace/src/github.com/miekg/dns/tlsa.go new file mode 100644 index 0000000000..d3bc3b0217 --- /dev/null +++ b/Godeps/_workspace/src/github.com/miekg/dns/tlsa.go @@ -0,0 +1,84 @@ +package dns + +import ( + "crypto/sha256" + "crypto/sha512" + "crypto/x509" + "encoding/hex" + "errors" + "io" + "net" + "strconv" +) + +// CertificateToDANE converts a certificate to a hex string as used in the TLSA record. +func CertificateToDANE(selector, matchingType uint8, cert *x509.Certificate) (string, error) { + switch matchingType { + case 0: + switch selector { + case 0: + return hex.EncodeToString(cert.Raw), nil + case 1: + return hex.EncodeToString(cert.RawSubjectPublicKeyInfo), nil + } + case 1: + h := sha256.New() + switch selector { + case 0: + return hex.EncodeToString(cert.Raw), nil + case 1: + io.WriteString(h, string(cert.RawSubjectPublicKeyInfo)) + return hex.EncodeToString(h.Sum(nil)), nil + } + case 2: + h := sha512.New() + switch selector { + case 0: + return hex.EncodeToString(cert.Raw), nil + case 1: + io.WriteString(h, string(cert.RawSubjectPublicKeyInfo)) + return hex.EncodeToString(h.Sum(nil)), nil + } + } + return "", errors.New("dns: bad TLSA MatchingType or TLSA Selector") +} + +// Sign creates a TLSA record from an SSL certificate. +func (r *TLSA) Sign(usage, selector, matchingType int, cert *x509.Certificate) (err error) { + r.Hdr.Rrtype = TypeTLSA + r.Usage = uint8(usage) + r.Selector = uint8(selector) + r.MatchingType = uint8(matchingType) + + r.Certificate, err = CertificateToDANE(r.Selector, r.MatchingType, cert) + if err != nil { + return err + } + return nil +} + +// Verify verifies a TLSA record against an SSL certificate. If it is OK +// a nil error is returned. +func (r *TLSA) Verify(cert *x509.Certificate) error { + c, err := CertificateToDANE(r.Selector, r.MatchingType, cert) + if err != nil { + return err // Not also ErrSig? + } + if r.Certificate == c { + return nil + } + return ErrSig // ErrSig, really? +} + +// TLSAName returns the ownername of a TLSA resource record as per the +// rules specified in RFC 6698, Section 3. 
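+//
+// For example, TLSAName("www.example.com.", "https", "tcp") returns
+// "_443._tcp.www.example.com." (net.LookupPort maps "https" to 443).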
+func TLSAName(name, service, network string) (string, error) {
+	if !IsFqdn(name) {
+		return "", ErrFqdn
+	}
+	p, e := net.LookupPort(network, service)
+	if e != nil {
+		return "", e
+	}
+	return "_" + strconv.Itoa(p) + "_" + network + "." + name, nil
+}
diff --git a/Godeps/_workspace/src/github.com/miekg/dns/tsig.go b/Godeps/_workspace/src/github.com/miekg/dns/tsig.go
new file mode 100644
index 0000000000..f31b15be4a
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/miekg/dns/tsig.go
@@ -0,0 +1,371 @@
+// TRANSACTION SIGNATURE
+//
+// A TSIG, or transaction signature, adds an HMAC TSIG record to each message sent.
+// The supported algorithms include: HmacMD5, HmacSHA1 and HmacSHA256.
+//
+// Basic use pattern when querying with a TSIG name "axfr." (note that these key names
+// must be fully qualified - as they are domain names) and the base64 secret
+// "so6ZGir4GPAqINNh9U5c3A==":
+//
+//	c := new(dns.Client)
+//	c.TsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="}
+//	m := new(dns.Msg)
+//	m.SetQuestion("miek.nl.", dns.TypeMX)
+//	m.SetTsig("axfr.", dns.HmacMD5, 300, time.Now().Unix())
+//	...
+//	// When sending, the TSIG RR is calculated and filled in before the message goes out.
+//
+// When requesting a zone transfer with TSIG (almost all TSIG usage is for zone
+// transfers), this is the basic use pattern. In this example we request an AXFR for
+// miek.nl. with the TSIG key named "axfr." and secret "so6ZGir4GPAqINNh9U5c3A==",
+// using the server 176.58.119.54:
+//
+//	t := new(dns.Transfer)
+//	m := new(dns.Msg)
+//	t.TsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="}
+//	m.SetAxfr("miek.nl.")
+//	m.SetTsig("axfr.", dns.HmacMD5, 300, time.Now().Unix())
+//	c, err := t.In(m, "176.58.119.54:53")
+//	for r := range c { /* r.RR */ }
+//
+// You can now read the records from the transfer as they come in. Each envelope is checked with TSIG.
+// If something is not correct an error is returned.
+//
+// Basic use pattern for validating and replying to a message that has TSIG set:
+//
+//	server := &dns.Server{Addr: ":53", Net: "udp"}
+//	server.TsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="}
+//	go server.ListenAndServe()
+//	dns.HandleFunc(".", handleRequest)
+//
+//	func handleRequest(w dns.ResponseWriter, r *dns.Msg) {
+//		m := new(dns.Msg)
+//		m.SetReply(r)
+//		if r.IsTsig() != nil {
+//			if w.TsigStatus() == nil {
+//				// *Msg r has a TSIG record and it was validated
+//				m.SetTsig("axfr.", dns.HmacMD5, 300, time.Now().Unix())
+//			} else {
+//				// *Msg r has a TSIG record and it was not validated
+//			}
+//		}
+//		w.WriteMsg(m)
+//	}
+package dns
+
+import (
+	"crypto/hmac"
+	"crypto/md5"
+	"crypto/sha1"
+	"crypto/sha256"
+	"encoding/hex"
+	"hash"
+	"io"
+	"strconv"
+	"strings"
+	"time"
+)
+
+// HMAC hashing codes. These are transmitted as domain names.
+const (
+	HmacMD5    = "hmac-md5.sig-alg.reg.int."
+	HmacSHA1   = "hmac-sha1."
+	HmacSHA256 = "hmac-sha256."
+)
+
+type TSIG struct {
+	Hdr        RR_Header
+	Algorithm  string `dns:"domain-name"`
+	TimeSigned uint64 `dns:"uint48"`
+	Fudge      uint16
+	MACSize    uint16
+	MAC        string `dns:"size-hex"`
+	OrigId     uint16
+	Error      uint16
+	OtherLen   uint16
+	OtherData  string `dns:"size-hex"`
+}
+
+func (rr *TSIG) Header() *RR_Header {
+	return &rr.Hdr
+}
+
+// TSIG has no official presentation format, but this will suffice.
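+// A rendered record looks roughly like this (values illustrative; the
+// exact field layout follows String below):
+//
+//	;; TSIG PSEUDOSECTION:
+//	axfr.	0	ANY	TSIG	hmac-md5.sig-alg.reg.int. 20150108195847 300 16 <hex MAC> 1234 0 0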
+ +func (rr *TSIG) String() string { + s := "\n;; TSIG PSEUDOSECTION:\n" + s += rr.Hdr.String() + + " " + rr.Algorithm + + " " + tsigTimeToString(rr.TimeSigned) + + " " + strconv.Itoa(int(rr.Fudge)) + + " " + strconv.Itoa(int(rr.MACSize)) + + " " + strings.ToUpper(rr.MAC) + + " " + strconv.Itoa(int(rr.OrigId)) + + " " + strconv.Itoa(int(rr.Error)) + // BIND prints NOERROR + " " + strconv.Itoa(int(rr.OtherLen)) + + " " + rr.OtherData + return s +} + +func (rr *TSIG) len() int { + return rr.Hdr.len() + len(rr.Algorithm) + 1 + 6 + + 4 + len(rr.MAC)/2 + 1 + 6 + len(rr.OtherData)/2 + 1 +} + +func (rr *TSIG) copy() RR { + return &TSIG{*rr.Hdr.copyHeader(), rr.Algorithm, rr.TimeSigned, rr.Fudge, rr.MACSize, rr.MAC, rr.OrigId, rr.Error, rr.OtherLen, rr.OtherData} +} + +// The following values must be put in wireformat, so that the MAC can be calculated. +// RFC 2845, section 3.4.2. TSIG Variables. +type tsigWireFmt struct { + // From RR_Header + Name string `dns:"domain-name"` + Class uint16 + Ttl uint32 + // Rdata of the TSIG + Algorithm string `dns:"domain-name"` + TimeSigned uint64 `dns:"uint48"` + Fudge uint16 + // MACSize, MAC and OrigId excluded + Error uint16 + OtherLen uint16 + OtherData string `dns:"size-hex"` +} + +// If we have the MAC use this type to convert it to wiredata. +// Section 3.4.3. Request MAC +type macWireFmt struct { + MACSize uint16 + MAC string `dns:"size-hex"` +} + +// 3.3. Time values used in TSIG calculations +type timerWireFmt struct { + TimeSigned uint64 `dns:"uint48"` + Fudge uint16 +} + +// TsigGenerate fills out the TSIG record attached to the message. +// The message should contain +// a "stub" TSIG RR with the algorithm, key name (owner name of the RR), +// time fudge (defaults to 300 seconds) and the current time +// The TSIG MAC is saved in that Tsig RR. +// When TsigGenerate is called for the first time requestMAC is set to the empty string and +// timersOnly is false. +// If something goes wrong an error is returned, otherwise it is nil. +func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, string, error) { + if m.IsTsig() == nil { + panic("dns: TSIG not last RR in additional") + } + // If we barf here, the caller is to blame + rawsecret, err := packBase64([]byte(secret)) + if err != nil { + return nil, "", err + } + + rr := m.Extra[len(m.Extra)-1].(*TSIG) + m.Extra = m.Extra[0 : len(m.Extra)-1] // kill the TSIG from the msg + mbuf, err := m.Pack() + if err != nil { + return nil, "", err + } + buf := tsigBuffer(mbuf, rr, requestMAC, timersOnly) + + t := new(TSIG) + var h hash.Hash + switch rr.Algorithm { + case HmacMD5: + h = hmac.New(md5.New, []byte(rawsecret)) + case HmacSHA1: + h = hmac.New(sha1.New, []byte(rawsecret)) + case HmacSHA256: + h = hmac.New(sha256.New, []byte(rawsecret)) + default: + return nil, "", ErrKeyAlg + } + io.WriteString(h, string(buf)) + t.MAC = hex.EncodeToString(h.Sum(nil)) + t.MACSize = uint16(len(t.MAC) / 2) // Size is half! + + t.Hdr = RR_Header{Name: rr.Hdr.Name, Rrtype: TypeTSIG, Class: ClassANY, Ttl: 0} + t.Fudge = rr.Fudge + t.TimeSigned = rr.TimeSigned + t.Algorithm = rr.Algorithm + t.OrigId = m.Id + + tbuf := make([]byte, t.len()) + if off, err := PackRR(t, tbuf, 0, nil, false); err == nil { + tbuf = tbuf[:off] // reset to actual size used + } else { + return nil, "", err + } + mbuf = append(mbuf, tbuf...) + rawSetExtraLen(mbuf, uint16(len(m.Extra)+1)) + return mbuf, t.MAC, nil +} + +// TsigVerify verifies the TSIG on a message. 
+// If the signature does not validate err contains the +// error, otherwise it is nil. +func TsigVerify(msg []byte, secret, requestMAC string, timersOnly bool) error { + rawsecret, err := packBase64([]byte(secret)) + if err != nil { + return err + } + // Strip the TSIG from the incoming msg + stripped, tsig, err := stripTsig(msg) + if err != nil { + return err + } + + msgMAC, err := hex.DecodeString(tsig.MAC) + if err != nil { + return err + } + + buf := tsigBuffer(stripped, tsig, requestMAC, timersOnly) + ti := uint64(time.Now().Unix()) - tsig.TimeSigned + if uint64(tsig.Fudge) < ti { + return ErrTime + } + + var h hash.Hash + switch tsig.Algorithm { + case HmacMD5: + h = hmac.New(md5.New, rawsecret) + case HmacSHA1: + h = hmac.New(sha1.New, rawsecret) + case HmacSHA256: + h = hmac.New(sha256.New, rawsecret) + default: + return ErrKeyAlg + } + h.Write(buf) + if !hmac.Equal(h.Sum(nil), msgMAC) { + return ErrSig + } + return nil +} + +// Create a wiredata buffer for the MAC calculation. +func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) []byte { + var buf []byte + if rr.TimeSigned == 0 { + rr.TimeSigned = uint64(time.Now().Unix()) + } + if rr.Fudge == 0 { + rr.Fudge = 300 // Standard (RFC) default. + } + + if requestMAC != "" { + m := new(macWireFmt) + m.MACSize = uint16(len(requestMAC) / 2) + m.MAC = requestMAC + buf = make([]byte, len(requestMAC)) // long enough + n, _ := PackStruct(m, buf, 0) + buf = buf[:n] + } + + tsigvar := make([]byte, DefaultMsgSize) + if timersOnly { + tsig := new(timerWireFmt) + tsig.TimeSigned = rr.TimeSigned + tsig.Fudge = rr.Fudge + n, _ := PackStruct(tsig, tsigvar, 0) + tsigvar = tsigvar[:n] + } else { + tsig := new(tsigWireFmt) + tsig.Name = strings.ToLower(rr.Hdr.Name) + tsig.Class = ClassANY + tsig.Ttl = rr.Hdr.Ttl + tsig.Algorithm = strings.ToLower(rr.Algorithm) + tsig.TimeSigned = rr.TimeSigned + tsig.Fudge = rr.Fudge + tsig.Error = rr.Error + tsig.OtherLen = rr.OtherLen + tsig.OtherData = rr.OtherData + n, _ := PackStruct(tsig, tsigvar, 0) + tsigvar = tsigvar[:n] + } + + if requestMAC != "" { + x := append(buf, msgbuf...) + buf = append(x, tsigvar...) + } else { + buf = append(msgbuf, tsigvar...) + } + return buf +} + +// Strip the TSIG from the raw message. +func stripTsig(msg []byte) ([]byte, *TSIG, error) { + // Copied from msg.go's Unpack() + // Header. + var dh Header + var err error + dns := new(Msg) + rr := new(TSIG) + off := 0 + tsigoff := 0 + if off, err = UnpackStruct(&dh, msg, off); err != nil { + return nil, nil, err + } + if dh.Arcount == 0 { + return nil, nil, ErrNoSig + } + // Rcode, see msg.go Unpack() + if int(dh.Bits&0xF) == RcodeNotAuth { + return nil, nil, ErrAuth + } + + // Arrays. + dns.Question = make([]Question, dh.Qdcount) + dns.Answer = make([]RR, dh.Ancount) + dns.Ns = make([]RR, dh.Nscount) + dns.Extra = make([]RR, dh.Arcount) + + for i := 0; i < len(dns.Question); i++ { + off, err = UnpackStruct(&dns.Question[i], msg, off) + if err != nil { + return nil, nil, err + } + } + for i := 0; i < len(dns.Answer); i++ { + dns.Answer[i], off, err = UnpackRR(msg, off) + if err != nil { + return nil, nil, err + } + } + for i := 0; i < len(dns.Ns); i++ { + dns.Ns[i], off, err = UnpackRR(msg, off) + if err != nil { + return nil, nil, err + } + } + for i := 0; i < len(dns.Extra); i++ { + tsigoff = off + dns.Extra[i], off, err = UnpackRR(msg, off) + if err != nil { + return nil, nil, err + } + if dns.Extra[i].Header().Rrtype == TypeTSIG { + rr = dns.Extra[i].(*TSIG) + // Adjust Arcount. 
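+			// (The TSIG RR is stripped from the returned message below,
+			// so the wire header's ARCOUNT, at offset 10, must advertise
+			// one fewer additional RR.)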
+ arcount, _ := unpackUint16(msg, 10) + msg[10], msg[11] = packUint16(arcount - 1) + break + } + } + if rr == nil { + return nil, nil, ErrNoSig + } + return msg[:tsigoff], rr, nil +} + +// Translate the TSIG time signed into a date. There is no +// need for RFC1982 calculations as this date is 48 bits. +func tsigTimeToString(t uint64) string { + ti := time.Unix(int64(t), 0).UTC() + return ti.Format("20060102150405") +} diff --git a/Godeps/_workspace/src/github.com/miekg/dns/types.go b/Godeps/_workspace/src/github.com/miekg/dns/types.go new file mode 100644 index 0000000000..d5b5be4065 --- /dev/null +++ b/Godeps/_workspace/src/github.com/miekg/dns/types.go @@ -0,0 +1,1685 @@ +package dns + +import ( + "encoding/base64" + "fmt" + "net" + "strconv" + "strings" + "time" +) + +type ( + Type uint16 // Type is a DNS type. + Class uint16 // Class is a DNS class. + Name string // Name is a DNS domain name. +) + +// Packet formats + +// Wire constants and supported types. +const ( + // valid RR_Header.Rrtype and Question.qtype + TypeNone uint16 = 0 + TypeA uint16 = 1 + TypeNS uint16 = 2 + TypeMD uint16 = 3 + TypeMF uint16 = 4 + TypeCNAME uint16 = 5 + TypeSOA uint16 = 6 + TypeMB uint16 = 7 + TypeMG uint16 = 8 + TypeMR uint16 = 9 + TypeNULL uint16 = 10 + TypeWKS uint16 = 11 + TypePTR uint16 = 12 + TypeHINFO uint16 = 13 + TypeMINFO uint16 = 14 + TypeMX uint16 = 15 + TypeTXT uint16 = 16 + TypeRP uint16 = 17 + TypeAFSDB uint16 = 18 + TypeX25 uint16 = 19 + TypeISDN uint16 = 20 + TypeRT uint16 = 21 + TypeNSAP uint16 = 22 + TypeNSAPPTR uint16 = 23 + TypeSIG uint16 = 24 + TypeKEY uint16 = 25 + TypePX uint16 = 26 + TypeGPOS uint16 = 27 + TypeAAAA uint16 = 28 + TypeLOC uint16 = 29 + TypeNXT uint16 = 30 + TypeEID uint16 = 31 + TypeNIMLOC uint16 = 32 + TypeSRV uint16 = 33 + TypeATMA uint16 = 34 + TypeNAPTR uint16 = 35 + TypeKX uint16 = 36 + TypeCERT uint16 = 37 + TypeDNAME uint16 = 39 + TypeOPT uint16 = 41 // EDNS + TypeDS uint16 = 43 + TypeSSHFP uint16 = 44 + TypeIPSECKEY uint16 = 45 + TypeRRSIG uint16 = 46 + TypeNSEC uint16 = 47 + TypeDNSKEY uint16 = 48 + TypeDHCID uint16 = 49 + TypeNSEC3 uint16 = 50 + TypeNSEC3PARAM uint16 = 51 + TypeTLSA uint16 = 52 + TypeHIP uint16 = 55 + TypeNINFO uint16 = 56 + TypeRKEY uint16 = 57 + TypeTALINK uint16 = 58 + TypeCDS uint16 = 59 + TypeOPENPGPKEY uint16 = 61 + TypeSPF uint16 = 99 + TypeUINFO uint16 = 100 + TypeUID uint16 = 101 + TypeGID uint16 = 102 + TypeUNSPEC uint16 = 103 + TypeNID uint16 = 104 + TypeL32 uint16 = 105 + TypeL64 uint16 = 106 + TypeLP uint16 = 107 + TypeEUI48 uint16 = 108 + TypeEUI64 uint16 = 109 + + TypeTKEY uint16 = 249 + TypeTSIG uint16 = 250 + // valid Question.Qtype only + TypeIXFR uint16 = 251 + TypeAXFR uint16 = 252 + TypeMAILB uint16 = 253 + TypeMAILA uint16 = 254 + TypeANY uint16 = 255 + + TypeURI uint16 = 256 + TypeCAA uint16 = 257 + TypeTA uint16 = 32768 + TypeDLV uint16 = 32769 + TypeReserved uint16 = 65535 + + // valid Question.Qclass + ClassINET = 1 + ClassCSNET = 2 + ClassCHAOS = 3 + ClassHESIOD = 4 + ClassNONE = 254 + ClassANY = 255 + + // Msg.rcode + RcodeSuccess = 0 + RcodeFormatError = 1 + RcodeServerFailure = 2 + RcodeNameError = 3 + RcodeNotImplemented = 4 + RcodeRefused = 5 + RcodeYXDomain = 6 + RcodeYXRrset = 7 + RcodeNXRrset = 8 + RcodeNotAuth = 9 + RcodeNotZone = 10 + RcodeBadSig = 16 // TSIG + RcodeBadVers = 16 // EDNS0 + RcodeBadKey = 17 + RcodeBadTime = 18 + RcodeBadMode = 19 // TKEY + RcodeBadName = 20 + RcodeBadAlg = 21 + RcodeBadTrunc = 22 // TSIG + + // Opcode + OpcodeQuery = 0 + OpcodeIQuery = 1 + OpcodeStatus = 2 + // There 
is no 3 + OpcodeNotify = 4 + OpcodeUpdate = 5 +) + +// The wire format for the DNS packet header. +type Header struct { + Id uint16 + Bits uint16 + Qdcount, Ancount, Nscount, Arcount uint16 +} + +const ( + // Header.Bits + _QR = 1 << 15 // query/response (response=1) + _AA = 1 << 10 // authoritative + _TC = 1 << 9 // truncated + _RD = 1 << 8 // recursion desired + _RA = 1 << 7 // recursion available + _Z = 1 << 6 // Z + _AD = 1 << 5 // authticated data + _CD = 1 << 4 // checking disabled + +) + +// RFC 1876, Section 2 +const _LOC_EQUATOR = 1 << 31 + +// RFC 4398, Section 2.1 +const ( + CertPKIX = 1 + iota + CertSPKI + CertPGP + CertIPIX + CertISPKI + CertIPGP + CertACPKIX + CertIACPKIX + CertURI = 253 + CertOID = 254 +) + +var CertTypeToString = map[uint16]string{ + CertPKIX: "PKIX", + CertSPKI: "SPKI", + CertPGP: "PGP", + CertIPIX: "IPIX", + CertISPKI: "ISPKI", + CertIPGP: "IPGP", + CertACPKIX: "ACPKIX", + CertIACPKIX: "IACPKIX", + CertURI: "URI", + CertOID: "OID", +} + +var StringToCertType = reverseInt16(CertTypeToString) + +// DNS queries. +type Question struct { + Name string `dns:"cdomain-name"` // "cdomain-name" specifies encoding (and may be compressed) + Qtype uint16 + Qclass uint16 +} + +func (q *Question) String() (s string) { + // prefix with ; (as in dig) + s = ";" + sprintName(q.Name) + "\t" + s += Class(q.Qclass).String() + "\t" + s += " " + Type(q.Qtype).String() + return s +} + +func (q *Question) len() int { + l := len(q.Name) + 1 + return l + 4 +} + +type ANY struct { + Hdr RR_Header + // Does not have any rdata +} + +func (rr *ANY) Header() *RR_Header { return &rr.Hdr } +func (rr *ANY) copy() RR { return &ANY{*rr.Hdr.copyHeader()} } +func (rr *ANY) String() string { return rr.Hdr.String() } +func (rr *ANY) len() int { return rr.Hdr.len() } + +type CNAME struct { + Hdr RR_Header + Target string `dns:"cdomain-name"` +} + +func (rr *CNAME) Header() *RR_Header { return &rr.Hdr } +func (rr *CNAME) copy() RR { return &CNAME{*rr.Hdr.copyHeader(), sprintName(rr.Target)} } +func (rr *CNAME) String() string { return rr.Hdr.String() + rr.Target } +func (rr *CNAME) len() int { return rr.Hdr.len() + len(rr.Target) + 1 } + +type HINFO struct { + Hdr RR_Header + Cpu string + Os string +} + +func (rr *HINFO) Header() *RR_Header { return &rr.Hdr } +func (rr *HINFO) copy() RR { return &HINFO{*rr.Hdr.copyHeader(), rr.Cpu, rr.Os} } +func (rr *HINFO) String() string { return rr.Hdr.String() + rr.Cpu + " " + rr.Os } +func (rr *HINFO) len() int { return rr.Hdr.len() + len(rr.Cpu) + len(rr.Os) } + +type MB struct { + Hdr RR_Header + Mb string `dns:"cdomain-name"` +} + +func (rr *MB) Header() *RR_Header { return &rr.Hdr } +func (rr *MB) copy() RR { return &MB{*rr.Hdr.copyHeader(), sprintName(rr.Mb)} } + +func (rr *MB) String() string { return rr.Hdr.String() + rr.Mb } +func (rr *MB) len() int { return rr.Hdr.len() + len(rr.Mb) + 1 } + +type MG struct { + Hdr RR_Header + Mg string `dns:"cdomain-name"` +} + +func (rr *MG) Header() *RR_Header { return &rr.Hdr } +func (rr *MG) copy() RR { return &MG{*rr.Hdr.copyHeader(), rr.Mg} } +func (rr *MG) len() int { l := len(rr.Mg) + 1; return rr.Hdr.len() + l } +func (rr *MG) String() string { return rr.Hdr.String() + sprintName(rr.Mg) } + +type MINFO struct { + Hdr RR_Header + Rmail string `dns:"cdomain-name"` + Email string `dns:"cdomain-name"` +} + +func (rr *MINFO) Header() *RR_Header { return &rr.Hdr } +func (rr *MINFO) copy() RR { return &MINFO{*rr.Hdr.copyHeader(), rr.Rmail, rr.Email} } + +func (rr *MINFO) String() string { + return rr.Hdr.String() + 
sprintName(rr.Rmail) + " " + sprintName(rr.Email) +} + +func (rr *MINFO) len() int { + l := len(rr.Rmail) + 1 + n := len(rr.Email) + 1 + return rr.Hdr.len() + l + n +} + +type MR struct { + Hdr RR_Header + Mr string `dns:"cdomain-name"` +} + +func (rr *MR) Header() *RR_Header { return &rr.Hdr } +func (rr *MR) copy() RR { return &MR{*rr.Hdr.copyHeader(), rr.Mr} } +func (rr *MR) len() int { l := len(rr.Mr) + 1; return rr.Hdr.len() + l } + +func (rr *MR) String() string { + return rr.Hdr.String() + sprintName(rr.Mr) +} + +type MF struct { + Hdr RR_Header + Mf string `dns:"cdomain-name"` +} + +func (rr *MF) Header() *RR_Header { return &rr.Hdr } +func (rr *MF) copy() RR { return &MF{*rr.Hdr.copyHeader(), rr.Mf} } +func (rr *MF) len() int { return rr.Hdr.len() + len(rr.Mf) + 1 } + +func (rr *MF) String() string { + return rr.Hdr.String() + " " + sprintName(rr.Mf) +} + +type MD struct { + Hdr RR_Header + Md string `dns:"cdomain-name"` +} + +func (rr *MD) Header() *RR_Header { return &rr.Hdr } +func (rr *MD) copy() RR { return &MD{*rr.Hdr.copyHeader(), rr.Md} } +func (rr *MD) len() int { return rr.Hdr.len() + len(rr.Md) + 1 } + +func (rr *MD) String() string { + return rr.Hdr.String() + " " + sprintName(rr.Md) +} + +type MX struct { + Hdr RR_Header + Preference uint16 + Mx string `dns:"cdomain-name"` +} + +func (rr *MX) Header() *RR_Header { return &rr.Hdr } +func (rr *MX) copy() RR { return &MX{*rr.Hdr.copyHeader(), rr.Preference, rr.Mx} } +func (rr *MX) len() int { l := len(rr.Mx) + 1; return rr.Hdr.len() + l + 2 } + +func (rr *MX) String() string { + return rr.Hdr.String() + strconv.Itoa(int(rr.Preference)) + " " + sprintName(rr.Mx) +} + +type AFSDB struct { + Hdr RR_Header + Subtype uint16 + Hostname string `dns:"cdomain-name"` +} + +func (rr *AFSDB) Header() *RR_Header { return &rr.Hdr } +func (rr *AFSDB) copy() RR { return &AFSDB{*rr.Hdr.copyHeader(), rr.Subtype, rr.Hostname} } +func (rr *AFSDB) len() int { l := len(rr.Hostname) + 1; return rr.Hdr.len() + l + 2 } + +func (rr *AFSDB) String() string { + return rr.Hdr.String() + strconv.Itoa(int(rr.Subtype)) + " " + sprintName(rr.Hostname) +} + +type X25 struct { + Hdr RR_Header + PSDNAddress string +} + +func (rr *X25) Header() *RR_Header { return &rr.Hdr } +func (rr *X25) copy() RR { return &X25{*rr.Hdr.copyHeader(), rr.PSDNAddress} } +func (rr *X25) len() int { return rr.Hdr.len() + len(rr.PSDNAddress) + 1 } + +func (rr *X25) String() string { + return rr.Hdr.String() + rr.PSDNAddress +} + +type RT struct { + Hdr RR_Header + Preference uint16 + Host string `dns:"cdomain-name"` +} + +func (rr *RT) Header() *RR_Header { return &rr.Hdr } +func (rr *RT) copy() RR { return &RT{*rr.Hdr.copyHeader(), rr.Preference, rr.Host} } +func (rr *RT) len() int { l := len(rr.Host) + 1; return rr.Hdr.len() + l + 2 } + +func (rr *RT) String() string { + return rr.Hdr.String() + strconv.Itoa(int(rr.Preference)) + " " + sprintName(rr.Host) +} + +type NS struct { + Hdr RR_Header + Ns string `dns:"cdomain-name"` +} + +func (rr *NS) Header() *RR_Header { return &rr.Hdr } +func (rr *NS) len() int { l := len(rr.Ns) + 1; return rr.Hdr.len() + l } +func (rr *NS) copy() RR { return &NS{*rr.Hdr.copyHeader(), rr.Ns} } + +func (rr *NS) String() string { + return rr.Hdr.String() + sprintName(rr.Ns) +} + +type PTR struct { + Hdr RR_Header + Ptr string `dns:"cdomain-name"` +} + +func (rr *PTR) Header() *RR_Header { return &rr.Hdr } +func (rr *PTR) copy() RR { return &PTR{*rr.Hdr.copyHeader(), rr.Ptr} } +func (rr *PTR) len() int { l := len(rr.Ptr) + 1; return rr.Hdr.len() + l 
} + +func (rr *PTR) String() string { + return rr.Hdr.String() + sprintName(rr.Ptr) +} + +type RP struct { + Hdr RR_Header + Mbox string `dns:"domain-name"` + Txt string `dns:"domain-name"` +} + +func (rr *RP) Header() *RR_Header { return &rr.Hdr } +func (rr *RP) copy() RR { return &RP{*rr.Hdr.copyHeader(), rr.Mbox, rr.Txt} } +func (rr *RP) len() int { return rr.Hdr.len() + len(rr.Mbox) + 1 + len(rr.Txt) + 1 } + +func (rr *RP) String() string { + return rr.Hdr.String() + rr.Mbox + " " + sprintTxt([]string{rr.Txt}) +} + +type SOA struct { + Hdr RR_Header + Ns string `dns:"cdomain-name"` + Mbox string `dns:"cdomain-name"` + Serial uint32 + Refresh uint32 + Retry uint32 + Expire uint32 + Minttl uint32 +} + +func (rr *SOA) Header() *RR_Header { return &rr.Hdr } +func (rr *SOA) copy() RR { + return &SOA{*rr.Hdr.copyHeader(), rr.Ns, rr.Mbox, rr.Serial, rr.Refresh, rr.Retry, rr.Expire, rr.Minttl} +} + +func (rr *SOA) String() string { + return rr.Hdr.String() + sprintName(rr.Ns) + " " + sprintName(rr.Mbox) + + " " + strconv.FormatInt(int64(rr.Serial), 10) + + " " + strconv.FormatInt(int64(rr.Refresh), 10) + + " " + strconv.FormatInt(int64(rr.Retry), 10) + + " " + strconv.FormatInt(int64(rr.Expire), 10) + + " " + strconv.FormatInt(int64(rr.Minttl), 10) +} + +func (rr *SOA) len() int { + l := len(rr.Ns) + 1 + n := len(rr.Mbox) + 1 + return rr.Hdr.len() + l + n + 20 +} + +type TXT struct { + Hdr RR_Header + Txt []string `dns:"txt"` +} + +func (rr *TXT) Header() *RR_Header { return &rr.Hdr } +func (rr *TXT) copy() RR { + cp := make([]string, len(rr.Txt), cap(rr.Txt)) + copy(cp, rr.Txt) + return &TXT{*rr.Hdr.copyHeader(), cp} +} + +func (rr *TXT) String() string { return rr.Hdr.String() + sprintTxt(rr.Txt) } + +func sprintName(s string) string { + src := []byte(s) + dst := make([]byte, 0, len(src)) + for i := 0; i < len(src); { + if i+1 < len(src) && src[i] == '\\' && src[i+1] == '.' { + dst = append(dst, src[i:i+2]...) + i += 2 + } else { + b, n := nextByte(src, i) + if n == 0 { + i++ // dangling back slash + } else if b == '.' { + dst = append(dst, b) + } else { + dst = appendDomainNameByte(dst, b) + } + i += n + } + } + return string(dst) +} + +func sprintTxt(txt []string) string { + var out []byte + for i, s := range txt { + if i > 0 { + out = append(out, ` "`...) + } else { + out = append(out, '"') + } + bs := []byte(s) + for j := 0; j < len(bs); { + b, n := nextByte(bs, j) + if n == 0 { + break + } + out = appendTXTStringByte(out, b) + j += n + } + out = append(out, '"') + } + return string(out) +} + +func appendDomainNameByte(s []byte, b byte) []byte { + switch b { + case '.', ' ', '\'', '@', ';', '(', ')': // additional chars to escape + return append(s, '\\', b) + } + return appendTXTStringByte(s, b) +} + +func appendTXTStringByte(s []byte, b byte) []byte { + switch b { + case '\t': + return append(s, '\\', 't') + case '\r': + return append(s, '\\', 'r') + case '\n': + return append(s, '\\', 'n') + case '"', '\\': + return append(s, '\\', b) + } + if b < ' ' || b > '~' { + return append(s, fmt.Sprintf("\\%03d", b)...) 
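+		// (Bytes outside the printable ASCII range are emitted as \DDD
+		// decimal escapes, following RFC 1035 zone file conventions.)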
+ } + return append(s, b) +} + +func nextByte(b []byte, offset int) (byte, int) { + if offset >= len(b) { + return 0, 0 + } + if b[offset] != '\\' { + // not an escape sequence + return b[offset], 1 + } + switch len(b) - offset { + case 1: // dangling escape + return 0, 0 + case 2, 3: // too short to be \ddd + default: // maybe \ddd + if isDigit(b[offset+1]) && isDigit(b[offset+2]) && isDigit(b[offset+3]) { + return dddToByte(b[offset+1:]), 4 + } + } + // not \ddd, maybe a control char + switch b[offset+1] { + case 't': + return '\t', 2 + case 'r': + return '\r', 2 + case 'n': + return '\n', 2 + default: + return b[offset+1], 2 + } +} + +func (rr *TXT) len() int { + l := rr.Hdr.len() + for _, t := range rr.Txt { + l += len(t) + 1 + } + return l +} + +type SPF struct { + Hdr RR_Header + Txt []string `dns:"txt"` +} + +func (rr *SPF) Header() *RR_Header { return &rr.Hdr } +func (rr *SPF) copy() RR { + cp := make([]string, len(rr.Txt), cap(rr.Txt)) + copy(cp, rr.Txt) + return &SPF{*rr.Hdr.copyHeader(), cp} +} + +func (rr *SPF) String() string { return rr.Hdr.String() + sprintTxt(rr.Txt) } + +func (rr *SPF) len() int { + l := rr.Hdr.len() + for _, t := range rr.Txt { + l += len(t) + 1 + } + return l +} + +type SRV struct { + Hdr RR_Header + Priority uint16 + Weight uint16 + Port uint16 + Target string `dns:"domain-name"` +} + +func (rr *SRV) Header() *RR_Header { return &rr.Hdr } +func (rr *SRV) len() int { l := len(rr.Target) + 1; return rr.Hdr.len() + l + 6 } +func (rr *SRV) copy() RR { + return &SRV{*rr.Hdr.copyHeader(), rr.Priority, rr.Weight, rr.Port, rr.Target} +} + +func (rr *SRV) String() string { + return rr.Hdr.String() + + strconv.Itoa(int(rr.Priority)) + " " + + strconv.Itoa(int(rr.Weight)) + " " + + strconv.Itoa(int(rr.Port)) + " " + sprintName(rr.Target) +} + +type NAPTR struct { + Hdr RR_Header + Order uint16 + Preference uint16 + Flags string + Service string + Regexp string + Replacement string `dns:"domain-name"` +} + +func (rr *NAPTR) Header() *RR_Header { return &rr.Hdr } +func (rr *NAPTR) copy() RR { + return &NAPTR{*rr.Hdr.copyHeader(), rr.Order, rr.Preference, rr.Flags, rr.Service, rr.Regexp, rr.Replacement} +} + +func (rr *NAPTR) String() string { + return rr.Hdr.String() + + strconv.Itoa(int(rr.Order)) + " " + + strconv.Itoa(int(rr.Preference)) + " " + + "\"" + rr.Flags + "\" " + + "\"" + rr.Service + "\" " + + "\"" + rr.Regexp + "\" " + + rr.Replacement +} + +func (rr *NAPTR) len() int { + return rr.Hdr.len() + 4 + len(rr.Flags) + 1 + len(rr.Service) + 1 + + len(rr.Regexp) + 1 + len(rr.Replacement) + 1 +} + +// See RFC 4398. +type CERT struct { + Hdr RR_Header + Type uint16 + KeyTag uint16 + Algorithm uint8 + Certificate string `dns:"base64"` +} + +func (rr *CERT) Header() *RR_Header { return &rr.Hdr } +func (rr *CERT) copy() RR { + return &CERT{*rr.Hdr.copyHeader(), rr.Type, rr.KeyTag, rr.Algorithm, rr.Certificate} +} + +func (rr *CERT) String() string { + var ( + ok bool + certtype, algorithm string + ) + if certtype, ok = CertTypeToString[rr.Type]; !ok { + certtype = strconv.Itoa(int(rr.Type)) + } + if algorithm, ok = AlgorithmToString[rr.Algorithm]; !ok { + algorithm = strconv.Itoa(int(rr.Algorithm)) + } + return rr.Hdr.String() + certtype + + " " + strconv.Itoa(int(rr.KeyTag)) + + " " + algorithm + + " " + rr.Certificate +} + +func (rr *CERT) len() int { + return rr.Hdr.len() + 5 + + base64.StdEncoding.DecodedLen(len(rr.Certificate)) +} + +// See RFC 2672. 
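+// (RFC 2672 has since been obsoleted by RFC 6672, which redefines DNAME.)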
+type DNAME struct { + Hdr RR_Header + Target string `dns:"domain-name"` +} + +func (rr *DNAME) Header() *RR_Header { return &rr.Hdr } +func (rr *DNAME) copy() RR { return &DNAME{*rr.Hdr.copyHeader(), rr.Target} } +func (rr *DNAME) len() int { l := len(rr.Target) + 1; return rr.Hdr.len() + l } + +func (rr *DNAME) String() string { + return rr.Hdr.String() + sprintName(rr.Target) +} + +type A struct { + Hdr RR_Header + A net.IP `dns:"a"` +} + +func (rr *A) Header() *RR_Header { return &rr.Hdr } +func (rr *A) copy() RR { return &A{*rr.Hdr.copyHeader(), copyIP(rr.A)} } +func (rr *A) len() int { return rr.Hdr.len() + net.IPv4len } + +func (rr *A) String() string { + if rr.A == nil { + return rr.Hdr.String() + } + return rr.Hdr.String() + rr.A.String() +} + +type AAAA struct { + Hdr RR_Header + AAAA net.IP `dns:"aaaa"` +} + +func (rr *AAAA) Header() *RR_Header { return &rr.Hdr } +func (rr *AAAA) copy() RR { return &AAAA{*rr.Hdr.copyHeader(), copyIP(rr.AAAA)} } +func (rr *AAAA) len() int { return rr.Hdr.len() + net.IPv6len } + +func (rr *AAAA) String() string { + if rr.AAAA == nil { + return rr.Hdr.String() + } + return rr.Hdr.String() + rr.AAAA.String() +} + +type PX struct { + Hdr RR_Header + Preference uint16 + Map822 string `dns:"domain-name"` + Mapx400 string `dns:"domain-name"` +} + +func (rr *PX) Header() *RR_Header { return &rr.Hdr } +func (rr *PX) copy() RR { return &PX{*rr.Hdr.copyHeader(), rr.Preference, rr.Map822, rr.Mapx400} } +func (rr *PX) String() string { + return rr.Hdr.String() + strconv.Itoa(int(rr.Preference)) + " " + sprintName(rr.Map822) + " " + sprintName(rr.Mapx400) +} +func (rr *PX) len() int { return rr.Hdr.len() + 2 + len(rr.Map822) + 1 + len(rr.Mapx400) + 1 } + +type GPOS struct { + Hdr RR_Header + Longitude string + Latitude string + Altitude string +} + +func (rr *GPOS) Header() *RR_Header { return &rr.Hdr } +func (rr *GPOS) copy() RR { return &GPOS{*rr.Hdr.copyHeader(), rr.Longitude, rr.Latitude, rr.Altitude} } +func (rr *GPOS) len() int { + return rr.Hdr.len() + len(rr.Longitude) + len(rr.Latitude) + len(rr.Altitude) + 3 +} +func (rr *GPOS) String() string { + return rr.Hdr.String() + rr.Longitude + " " + rr.Latitude + " " + rr.Altitude +} + +type LOC struct { + Hdr RR_Header + Version uint8 + Size uint8 + HorizPre uint8 + VertPre uint8 + Latitude uint32 + Longitude uint32 + Altitude uint32 +} + +func (rr *LOC) Header() *RR_Header { return &rr.Hdr } +func (rr *LOC) len() int { return rr.Hdr.len() + 4 + 12 } +func (rr *LOC) copy() RR { + return &LOC{*rr.Hdr.copyHeader(), rr.Version, rr.Size, rr.HorizPre, rr.VertPre, rr.Latitude, rr.Longitude, rr.Altitude} +} + +func (rr *LOC) String() string { + s := rr.Hdr.String() + // Copied from ldns + // Latitude + lat := rr.Latitude + north := "N" + if lat > _LOC_EQUATOR { + lat = lat - _LOC_EQUATOR + } else { + north = "S" + lat = _LOC_EQUATOR - lat + } + h := lat / (1000 * 60 * 60) + lat = lat % (1000 * 60 * 60) + m := lat / (1000 * 60) + lat = lat % (1000 * 60) + s += fmt.Sprintf("%02d %02d %0.3f %s ", h, m, (float32(lat) / 1000), north) + // Longitude + lon := rr.Longitude + east := "E" + if lon > _LOC_EQUATOR { + lon = lon - _LOC_EQUATOR + } else { + east = "W" + lon = _LOC_EQUATOR - lon + } + h = lon / (1000 * 60 * 60) + lon = lon % (1000 * 60 * 60) + m = lon / (1000 * 60) + lon = lon % (1000 * 60) + s += fmt.Sprintf("%02d %02d %0.3f %s ", h, m, (float32(lon) / 1000), east) + + s1 := rr.Altitude / 100.00 + s1 -= 100000 + if rr.Altitude%100 == 0 { + s += fmt.Sprintf("%.2fm ", float32(s1)) + } else { + s += 
fmt.Sprintf("%.0fm ", float32(s1)) + } + s += cmToString((rr.Size&0xf0)>>4, rr.Size&0x0f) + "m " + s += cmToString((rr.HorizPre&0xf0)>>4, rr.HorizPre&0x0f) + "m " + s += cmToString((rr.VertPre&0xf0)>>4, rr.VertPre&0x0f) + "m" + return s +} + +type RRSIG struct { + Hdr RR_Header + TypeCovered uint16 + Algorithm uint8 + Labels uint8 + OrigTtl uint32 + Expiration uint32 + Inception uint32 + KeyTag uint16 + SignerName string `dns:"domain-name"` + Signature string `dns:"base64"` +} + +func (rr *RRSIG) Header() *RR_Header { return &rr.Hdr } +func (rr *RRSIG) copy() RR { + return &RRSIG{*rr.Hdr.copyHeader(), rr.TypeCovered, rr.Algorithm, rr.Labels, rr.OrigTtl, rr.Expiration, rr.Inception, rr.KeyTag, rr.SignerName, rr.Signature} +} + +func (rr *RRSIG) String() string { + s := rr.Hdr.String() + s += Type(rr.TypeCovered).String() + s += " " + strconv.Itoa(int(rr.Algorithm)) + + " " + strconv.Itoa(int(rr.Labels)) + + " " + strconv.FormatInt(int64(rr.OrigTtl), 10) + + " " + TimeToString(rr.Expiration) + + " " + TimeToString(rr.Inception) + + " " + strconv.Itoa(int(rr.KeyTag)) + + " " + sprintName(rr.SignerName) + + " " + rr.Signature + return s +} + +func (rr *RRSIG) len() int { + return rr.Hdr.len() + len(rr.SignerName) + 1 + + base64.StdEncoding.DecodedLen(len(rr.Signature)) + 18 +} + +type NSEC struct { + Hdr RR_Header + NextDomain string `dns:"domain-name"` + TypeBitMap []uint16 `dns:"nsec"` +} + +func (rr *NSEC) Header() *RR_Header { return &rr.Hdr } +func (rr *NSEC) copy() RR { + cp := make([]uint16, len(rr.TypeBitMap), cap(rr.TypeBitMap)) + copy(cp, rr.TypeBitMap) + return &NSEC{*rr.Hdr.copyHeader(), rr.NextDomain, cp} +} + +func (rr *NSEC) String() string { + s := rr.Hdr.String() + sprintName(rr.NextDomain) + for i := 0; i < len(rr.TypeBitMap); i++ { + s += " " + Type(rr.TypeBitMap[i]).String() + } + return s +} + +func (rr *NSEC) len() int { + l := rr.Hdr.len() + len(rr.NextDomain) + 1 + lastwindow := uint32(2 ^ 32 + 1) + for _, t := range rr.TypeBitMap { + window := t / 256 + if uint32(window) != lastwindow { + l += 1 + 32 + } + lastwindow = uint32(window) + } + return l +} + +type DS struct { + Hdr RR_Header + KeyTag uint16 + Algorithm uint8 + DigestType uint8 + Digest string `dns:"hex"` +} + +func (rr *DS) Header() *RR_Header { return &rr.Hdr } +func (rr *DS) len() int { return rr.Hdr.len() + 4 + len(rr.Digest)/2 } +func (rr *DS) copy() RR { + return &DS{*rr.Hdr.copyHeader(), rr.KeyTag, rr.Algorithm, rr.DigestType, rr.Digest} +} + +func (rr *DS) String() string { + return rr.Hdr.String() + strconv.Itoa(int(rr.KeyTag)) + + " " + strconv.Itoa(int(rr.Algorithm)) + + " " + strconv.Itoa(int(rr.DigestType)) + + " " + strings.ToUpper(rr.Digest) +} + +type CDS struct { + Hdr RR_Header + KeyTag uint16 + Algorithm uint8 + DigestType uint8 + Digest string `dns:"hex"` +} + +func (rr *CDS) Header() *RR_Header { return &rr.Hdr } +func (rr *CDS) len() int { return rr.Hdr.len() + 4 + len(rr.Digest)/2 } +func (rr *CDS) copy() RR { + return &CDS{*rr.Hdr.copyHeader(), rr.KeyTag, rr.Algorithm, rr.DigestType, rr.Digest} +} + +func (rr *CDS) String() string { + return rr.Hdr.String() + strconv.Itoa(int(rr.KeyTag)) + + " " + strconv.Itoa(int(rr.Algorithm)) + + " " + strconv.Itoa(int(rr.DigestType)) + + " " + strings.ToUpper(rr.Digest) +} + +type DLV struct { + Hdr RR_Header + KeyTag uint16 + Algorithm uint8 + DigestType uint8 + Digest string `dns:"hex"` +} + +func (rr *DLV) Header() *RR_Header { return &rr.Hdr } +func (rr *DLV) len() int { return rr.Hdr.len() + 4 + len(rr.Digest)/2 } +func (rr *DLV) copy() RR { 
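+	// DLV shares the DS wire format (RFC 4431), so a plain field copy suffices.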
+ return &DLV{*rr.Hdr.copyHeader(), rr.KeyTag, rr.Algorithm, rr.DigestType, rr.Digest} +} + +func (rr *DLV) String() string { + return rr.Hdr.String() + strconv.Itoa(int(rr.KeyTag)) + + " " + strconv.Itoa(int(rr.Algorithm)) + + " " + strconv.Itoa(int(rr.DigestType)) + + " " + strings.ToUpper(rr.Digest) +} + +type KX struct { + Hdr RR_Header + Preference uint16 + Exchanger string `dns:"domain-name"` +} + +func (rr *KX) Header() *RR_Header { return &rr.Hdr } +func (rr *KX) len() int { return rr.Hdr.len() + 2 + len(rr.Exchanger) + 1 } +func (rr *KX) copy() RR { return &KX{*rr.Hdr.copyHeader(), rr.Preference, rr.Exchanger} } + +func (rr *KX) String() string { + return rr.Hdr.String() + strconv.Itoa(int(rr.Preference)) + + " " + sprintName(rr.Exchanger) +} + +type TA struct { + Hdr RR_Header + KeyTag uint16 + Algorithm uint8 + DigestType uint8 + Digest string `dns:"hex"` +} + +func (rr *TA) Header() *RR_Header { return &rr.Hdr } +func (rr *TA) len() int { return rr.Hdr.len() + 4 + len(rr.Digest)/2 } +func (rr *TA) copy() RR { + return &TA{*rr.Hdr.copyHeader(), rr.KeyTag, rr.Algorithm, rr.DigestType, rr.Digest} +} + +func (rr *TA) String() string { + return rr.Hdr.String() + strconv.Itoa(int(rr.KeyTag)) + + " " + strconv.Itoa(int(rr.Algorithm)) + + " " + strconv.Itoa(int(rr.DigestType)) + + " " + strings.ToUpper(rr.Digest) +} + +type TALINK struct { + Hdr RR_Header + PreviousName string `dns:"domain-name"` + NextName string `dns:"domain-name"` +} + +func (rr *TALINK) Header() *RR_Header { return &rr.Hdr } +func (rr *TALINK) copy() RR { return &TALINK{*rr.Hdr.copyHeader(), rr.PreviousName, rr.NextName} } +func (rr *TALINK) len() int { return rr.Hdr.len() + len(rr.PreviousName) + len(rr.NextName) + 2 } + +func (rr *TALINK) String() string { + return rr.Hdr.String() + + " " + sprintName(rr.PreviousName) + " " + sprintName(rr.NextName) +} + +type SSHFP struct { + Hdr RR_Header + Algorithm uint8 + Type uint8 + FingerPrint string `dns:"hex"` +} + +func (rr *SSHFP) Header() *RR_Header { return &rr.Hdr } +func (rr *SSHFP) len() int { return rr.Hdr.len() + 2 + len(rr.FingerPrint)/2 } +func (rr *SSHFP) copy() RR { + return &SSHFP{*rr.Hdr.copyHeader(), rr.Algorithm, rr.Type, rr.FingerPrint} +} + +func (rr *SSHFP) String() string { + return rr.Hdr.String() + strconv.Itoa(int(rr.Algorithm)) + + " " + strconv.Itoa(int(rr.Type)) + + " " + strings.ToUpper(rr.FingerPrint) +} + +type IPSECKEY struct { + Hdr RR_Header + Precedence uint8 + GatewayType uint8 + Algorithm uint8 + Gateway string `dns:"ipseckey"` + PublicKey string `dns:"base64"` +} + +func (rr *IPSECKEY) Header() *RR_Header { return &rr.Hdr } +func (rr *IPSECKEY) copy() RR { + return &IPSECKEY{*rr.Hdr.copyHeader(), rr.Precedence, rr.GatewayType, rr.Algorithm, rr.Gateway, rr.PublicKey} +} + +func (rr *IPSECKEY) String() string { + return rr.Hdr.String() + strconv.Itoa(int(rr.Precedence)) + + " " + strconv.Itoa(int(rr.GatewayType)) + + " " + strconv.Itoa(int(rr.Algorithm)) + + " " + rr.Gateway + + " " + rr.PublicKey +} + +func (rr *IPSECKEY) len() int { + return rr.Hdr.len() + 3 + len(rr.Gateway) + 1 + + base64.StdEncoding.DecodedLen(len(rr.PublicKey)) +} + +type DNSKEY struct { + Hdr RR_Header + Flags uint16 + Protocol uint8 + Algorithm uint8 + PublicKey string `dns:"base64"` +} + +func (rr *DNSKEY) Header() *RR_Header { return &rr.Hdr } +func (rr *DNSKEY) len() int { + return rr.Hdr.len() + 4 + base64.StdEncoding.DecodedLen(len(rr.PublicKey)) +} +func (rr *DNSKEY) copy() RR { + return &DNSKEY{*rr.Hdr.copyHeader(), rr.Flags, rr.Protocol, rr.Algorithm, 
rr.PublicKey} +} + +func (rr *DNSKEY) String() string { + return rr.Hdr.String() + strconv.Itoa(int(rr.Flags)) + + " " + strconv.Itoa(int(rr.Protocol)) + + " " + strconv.Itoa(int(rr.Algorithm)) + + " " + rr.PublicKey +} + +type RKEY struct { + Hdr RR_Header + Flags uint16 + Protocol uint8 + Algorithm uint8 + PublicKey string `dns:"base64"` +} + +func (rr *RKEY) Header() *RR_Header { return &rr.Hdr } +func (rr *RKEY) len() int { return rr.Hdr.len() + 4 + base64.StdEncoding.DecodedLen(len(rr.PublicKey)) } +func (rr *RKEY) copy() RR { + return &RKEY{*rr.Hdr.copyHeader(), rr.Flags, rr.Protocol, rr.Algorithm, rr.PublicKey} +} + +func (rr *RKEY) String() string { + return rr.Hdr.String() + strconv.Itoa(int(rr.Flags)) + + " " + strconv.Itoa(int(rr.Protocol)) + + " " + strconv.Itoa(int(rr.Algorithm)) + + " " + rr.PublicKey +} + +type NSAP struct { + Hdr RR_Header + Length uint8 + Nsap string +} + +func (rr *NSAP) Header() *RR_Header { return &rr.Hdr } +func (rr *NSAP) copy() RR { return &NSAP{*rr.Hdr.copyHeader(), rr.Length, rr.Nsap} } +func (rr *NSAP) String() string { return rr.Hdr.String() + strconv.Itoa(int(rr.Length)) + " " + rr.Nsap } +func (rr *NSAP) len() int { return rr.Hdr.len() + 1 + len(rr.Nsap) + 1 } + +type NSAPPTR struct { + Hdr RR_Header + Ptr string `dns:"domain-name"` +} + +func (rr *NSAPPTR) Header() *RR_Header { return &rr.Hdr } +func (rr *NSAPPTR) copy() RR { return &NSAPPTR{*rr.Hdr.copyHeader(), rr.Ptr} } +func (rr *NSAPPTR) String() string { return rr.Hdr.String() + sprintName(rr.Ptr) } +func (rr *NSAPPTR) len() int { return rr.Hdr.len() + len(rr.Ptr) } + +type NSEC3 struct { + Hdr RR_Header + Hash uint8 + Flags uint8 + Iterations uint16 + SaltLength uint8 + Salt string `dns:"size-hex"` + HashLength uint8 + NextDomain string `dns:"size-base32"` + TypeBitMap []uint16 `dns:"nsec"` +} + +func (rr *NSEC3) Header() *RR_Header { return &rr.Hdr } +func (rr *NSEC3) copy() RR { + cp := make([]uint16, len(rr.TypeBitMap), cap(rr.TypeBitMap)) + copy(cp, rr.TypeBitMap) + return &NSEC3{*rr.Hdr.copyHeader(), rr.Hash, rr.Flags, rr.Iterations, rr.SaltLength, rr.Salt, rr.HashLength, rr.NextDomain, cp} +} + +func (rr *NSEC3) String() string { + s := rr.Hdr.String() + s += strconv.Itoa(int(rr.Hash)) + + " " + strconv.Itoa(int(rr.Flags)) + + " " + strconv.Itoa(int(rr.Iterations)) + + " " + saltToString(rr.Salt) + + " " + rr.NextDomain + for i := 0; i < len(rr.TypeBitMap); i++ { + s += " " + Type(rr.TypeBitMap[i]).String() + } + return s +} + +func (rr *NSEC3) len() int { + l := rr.Hdr.len() + 6 + len(rr.Salt)/2 + 1 + len(rr.NextDomain) + 1 + lastwindow := uint32(2 ^ 32 + 1) + for _, t := range rr.TypeBitMap { + window := t / 256 + if uint32(window) != lastwindow { + l += 1 + 32 + } + lastwindow = uint32(window) + } + return l +} + +type NSEC3PARAM struct { + Hdr RR_Header + Hash uint8 + Flags uint8 + Iterations uint16 + SaltLength uint8 + Salt string `dns:"hex"` +} + +func (rr *NSEC3PARAM) Header() *RR_Header { return &rr.Hdr } +func (rr *NSEC3PARAM) len() int { return rr.Hdr.len() + 2 + 4 + 1 + len(rr.Salt)/2 } +func (rr *NSEC3PARAM) copy() RR { + return &NSEC3PARAM{*rr.Hdr.copyHeader(), rr.Hash, rr.Flags, rr.Iterations, rr.SaltLength, rr.Salt} +} + +func (rr *NSEC3PARAM) String() string { + s := rr.Hdr.String() + s += strconv.Itoa(int(rr.Hash)) + + " " + strconv.Itoa(int(rr.Flags)) + + " " + strconv.Itoa(int(rr.Iterations)) + + " " + saltToString(rr.Salt) + return s +} + +type TKEY struct { + Hdr RR_Header + Algorithm string `dns:"domain-name"` + Inception uint32 + Expiration uint32 + Mode 
uint16 + Error uint16 + KeySize uint16 + Key string + OtherLen uint16 + OtherData string +} + +func (rr *TKEY) Header() *RR_Header { return &rr.Hdr } +func (rr *TKEY) copy() RR { + return &TKEY{*rr.Hdr.copyHeader(), rr.Algorithm, rr.Inception, rr.Expiration, rr.Mode, rr.Error, rr.KeySize, rr.Key, rr.OtherLen, rr.OtherData} +} + +func (rr *TKEY) String() string { + // It has no presentation format + return "" +} + +func (rr *TKEY) len() int { + return rr.Hdr.len() + len(rr.Algorithm) + 1 + 4 + 4 + 6 + + len(rr.Key) + 2 + len(rr.OtherData) +} + +// RFC3597 represents an unknown/generic RR. +type RFC3597 struct { + Hdr RR_Header + Rdata string `dns:"hex"` +} + +func (rr *RFC3597) Header() *RR_Header { return &rr.Hdr } +func (rr *RFC3597) copy() RR { return &RFC3597{*rr.Hdr.copyHeader(), rr.Rdata} } +func (rr *RFC3597) len() int { return rr.Hdr.len() + len(rr.Rdata)/2 + 2 } + +func (rr *RFC3597) String() string { + s := rr.Hdr.String() + s += "\\# " + strconv.Itoa(len(rr.Rdata)/2) + " " + rr.Rdata + return s +} + +type URI struct { + Hdr RR_Header + Priority uint16 + Weight uint16 + Target []string `dns:"txt"` +} + +func (rr *URI) Header() *RR_Header { return &rr.Hdr } +func (rr *URI) copy() RR { + cp := make([]string, len(rr.Target), cap(rr.Target)) + copy(cp, rr.Target) + return &URI{*rr.Hdr.copyHeader(), rr.Weight, rr.Priority, cp} +} + +func (rr *URI) String() string { + return rr.Hdr.String() + strconv.Itoa(int(rr.Priority)) + + " " + strconv.Itoa(int(rr.Weight)) + sprintTxt(rr.Target) +} + +func (rr *URI) len() int { + l := rr.Hdr.len() + 4 + for _, t := range rr.Target { + l += len(t) + 1 + } + return l +} + +type DHCID struct { + Hdr RR_Header + Digest string `dns:"base64"` +} + +func (rr *DHCID) Header() *RR_Header { return &rr.Hdr } +func (rr *DHCID) copy() RR { return &DHCID{*rr.Hdr.copyHeader(), rr.Digest} } +func (rr *DHCID) String() string { return rr.Hdr.String() + rr.Digest } +func (rr *DHCID) len() int { return rr.Hdr.len() + base64.StdEncoding.DecodedLen(len(rr.Digest)) } + +type TLSA struct { + Hdr RR_Header + Usage uint8 + Selector uint8 + MatchingType uint8 + Certificate string `dns:"hex"` +} + +func (rr *TLSA) Header() *RR_Header { return &rr.Hdr } +func (rr *TLSA) len() int { return rr.Hdr.len() + 3 + len(rr.Certificate)/2 } + +func (rr *TLSA) copy() RR { + return &TLSA{*rr.Hdr.copyHeader(), rr.Usage, rr.Selector, rr.MatchingType, rr.Certificate} +} + +func (rr *TLSA) String() string { + return rr.Hdr.String() + + " " + strconv.Itoa(int(rr.Usage)) + + " " + strconv.Itoa(int(rr.Selector)) + + " " + strconv.Itoa(int(rr.MatchingType)) + + " " + rr.Certificate +} + +type HIP struct { + Hdr RR_Header + HitLength uint8 + PublicKeyAlgorithm uint8 + PublicKeyLength uint16 + Hit string `dns:"hex"` + PublicKey string `dns:"base64"` + RendezvousServers []string `dns:"domain-name"` +} + +func (rr *HIP) Header() *RR_Header { return &rr.Hdr } +func (rr *HIP) copy() RR { + cp := make([]string, len(rr.RendezvousServers), cap(rr.RendezvousServers)) + copy(cp, rr.RendezvousServers) + return &HIP{*rr.Hdr.copyHeader(), rr.HitLength, rr.PublicKeyAlgorithm, rr.PublicKeyLength, rr.Hit, rr.PublicKey, cp} +} + +func (rr *HIP) String() string { + s := rr.Hdr.String() + + " " + strconv.Itoa(int(rr.PublicKeyAlgorithm)) + + " " + rr.Hit + + " " + rr.PublicKey + for _, d := range rr.RendezvousServers { + s += " " + sprintName(d) + } + return s +} + +func (rr *HIP) len() int { + l := rr.Hdr.len() + 4 + + len(rr.Hit)/2 + + base64.StdEncoding.DecodedLen(len(rr.PublicKey)) + for _, d := range 
rr.RendezvousServers { + l += len(d) + 1 + } + return l +} + +type NINFO struct { + Hdr RR_Header + ZSData []string `dns:"txt"` +} + +func (rr *NINFO) Header() *RR_Header { return &rr.Hdr } +func (rr *NINFO) copy() RR { + cp := make([]string, len(rr.ZSData), cap(rr.ZSData)) + copy(cp, rr.ZSData) + return &NINFO{*rr.Hdr.copyHeader(), cp} +} + +func (rr *NINFO) String() string { return rr.Hdr.String() + sprintTxt(rr.ZSData) } + +func (rr *NINFO) len() int { + l := rr.Hdr.len() + for _, t := range rr.ZSData { + l += len(t) + 1 + } + return l +} + +type WKS struct { + Hdr RR_Header + Address net.IP `dns:"a"` + Protocol uint8 + BitMap []uint16 `dns:"wks"` +} + +func (rr *WKS) Header() *RR_Header { return &rr.Hdr } +func (rr *WKS) len() int { return rr.Hdr.len() + net.IPv4len + 1 } + +func (rr *WKS) copy() RR { + cp := make([]uint16, len(rr.BitMap), cap(rr.BitMap)) + copy(cp, rr.BitMap) + return &WKS{*rr.Hdr.copyHeader(), copyIP(rr.Address), rr.Protocol, cp} +} + +func (rr *WKS) String() (s string) { + s = rr.Hdr.String() + if rr.Address != nil { + s += rr.Address.String() + } + for i := 0; i < len(rr.BitMap); i++ { + // should lookup the port + s += " " + strconv.Itoa(int(rr.BitMap[i])) + } + return s +} + +type NID struct { + Hdr RR_Header + Preference uint16 + NodeID uint64 +} + +func (rr *NID) Header() *RR_Header { return &rr.Hdr } +func (rr *NID) copy() RR { return &NID{*rr.Hdr.copyHeader(), rr.Preference, rr.NodeID} } +func (rr *NID) len() int { return rr.Hdr.len() + 2 + 8 } + +func (rr *NID) String() string { + s := rr.Hdr.String() + strconv.Itoa(int(rr.Preference)) + node := fmt.Sprintf("%0.16x", rr.NodeID) + s += " " + node[0:4] + ":" + node[4:8] + ":" + node[8:12] + ":" + node[12:16] + return s +} + +type L32 struct { + Hdr RR_Header + Preference uint16 + Locator32 net.IP `dns:"a"` +} + +func (rr *L32) Header() *RR_Header { return &rr.Hdr } +func (rr *L32) copy() RR { return &L32{*rr.Hdr.copyHeader(), rr.Preference, copyIP(rr.Locator32)} } +func (rr *L32) len() int { return rr.Hdr.len() + net.IPv4len } + +func (rr *L32) String() string { + if rr.Locator32 == nil { + return rr.Hdr.String() + strconv.Itoa(int(rr.Preference)) + } + return rr.Hdr.String() + strconv.Itoa(int(rr.Preference)) + + " " + rr.Locator32.String() +} + +type L64 struct { + Hdr RR_Header + Preference uint16 + Locator64 uint64 +} + +func (rr *L64) Header() *RR_Header { return &rr.Hdr } +func (rr *L64) copy() RR { return &L64{*rr.Hdr.copyHeader(), rr.Preference, rr.Locator64} } +func (rr *L64) len() int { return rr.Hdr.len() + 2 + 8 } + +func (rr *L64) String() string { + s := rr.Hdr.String() + strconv.Itoa(int(rr.Preference)) + node := fmt.Sprintf("%0.16X", rr.Locator64) + s += " " + node[0:4] + ":" + node[4:8] + ":" + node[8:12] + ":" + node[12:16] + return s +} + +type LP struct { + Hdr RR_Header + Preference uint16 + Fqdn string `dns:"domain-name"` +} + +func (rr *LP) Header() *RR_Header { return &rr.Hdr } +func (rr *LP) copy() RR { return &LP{*rr.Hdr.copyHeader(), rr.Preference, rr.Fqdn} } +func (rr *LP) len() int { return rr.Hdr.len() + 2 + len(rr.Fqdn) + 1 } + +func (rr *LP) String() string { + return rr.Hdr.String() + strconv.Itoa(int(rr.Preference)) + " " + sprintName(rr.Fqdn) +} + +type EUI48 struct { + Hdr RR_Header + Address uint64 `dns:"uint48"` +} + +func (rr *EUI48) Header() *RR_Header { return &rr.Hdr } +func (rr *EUI48) copy() RR { return &EUI48{*rr.Hdr.copyHeader(), rr.Address} } +func (rr *EUI48) String() string { return rr.Hdr.String() + euiToString(rr.Address, 48) } +func (rr *EUI48) len() int { 
return rr.Hdr.len() + 6 } + +type EUI64 struct { + Hdr RR_Header + Address uint64 +} + +func (rr *EUI64) Header() *RR_Header { return &rr.Hdr } +func (rr *EUI64) copy() RR { return &EUI64{*rr.Hdr.copyHeader(), rr.Address} } +func (rr *EUI64) String() string { return rr.Hdr.String() + euiToString(rr.Address, 64) } +func (rr *EUI64) len() int { return rr.Hdr.len() + 8 } + +// Support in incomplete - just handle it as unknown record +/* +type CAA struct { + Hdr RR_Header + Flag uint8 + Tag string + Value string `dns:"octet"` +} + +func (rr *CAA) Header() *RR_Header { return &rr.Hdr } +func (rr *CAA) copy() RR { return &CAA{*rr.Hdr.copyHeader(), rr.Flag, rr.Tag, rr.Value} } +func (rr *CAA) len() int { return rr.Hdr.len() + 1 + len(rr.Tag) + 1 + len(rr.Value) } + +func (rr *CAA) String() string { + s := rr.Hdr.String() + strconv.FormatInt(int64(rr.Flag), 10) + " " + rr.Tag + s += strconv.QuoteToASCII(rr.Value) + return s +} +*/ + +type UID struct { + Hdr RR_Header + Uid uint32 +} + +func (rr *UID) Header() *RR_Header { return &rr.Hdr } +func (rr *UID) copy() RR { return &UID{*rr.Hdr.copyHeader(), rr.Uid} } +func (rr *UID) String() string { return rr.Hdr.String() + strconv.FormatInt(int64(rr.Uid), 10) } +func (rr *UID) len() int { return rr.Hdr.len() + 4 } + +type GID struct { + Hdr RR_Header + Gid uint32 +} + +func (rr *GID) Header() *RR_Header { return &rr.Hdr } +func (rr *GID) copy() RR { return &GID{*rr.Hdr.copyHeader(), rr.Gid} } +func (rr *GID) String() string { return rr.Hdr.String() + strconv.FormatInt(int64(rr.Gid), 10) } +func (rr *GID) len() int { return rr.Hdr.len() + 4 } + +type UINFO struct { + Hdr RR_Header + Uinfo string +} + +func (rr *UINFO) Header() *RR_Header { return &rr.Hdr } +func (rr *UINFO) copy() RR { return &UINFO{*rr.Hdr.copyHeader(), rr.Uinfo} } +func (rr *UINFO) String() string { return rr.Hdr.String() + sprintTxt([]string{rr.Uinfo}) } +func (rr *UINFO) len() int { return rr.Hdr.len() + len(rr.Uinfo) + 1 } + +type EID struct { + Hdr RR_Header + Endpoint string `dns:"hex"` +} + +func (rr *EID) Header() *RR_Header { return &rr.Hdr } +func (rr *EID) copy() RR { return &EID{*rr.Hdr.copyHeader(), rr.Endpoint} } +func (rr *EID) String() string { return rr.Hdr.String() + strings.ToUpper(rr.Endpoint) } +func (rr *EID) len() int { return rr.Hdr.len() + len(rr.Endpoint)/2 } + +type NIMLOC struct { + Hdr RR_Header + Locator string `dns:"hex"` +} + +func (rr *NIMLOC) Header() *RR_Header { return &rr.Hdr } +func (rr *NIMLOC) copy() RR { return &NIMLOC{*rr.Hdr.copyHeader(), rr.Locator} } +func (rr *NIMLOC) String() string { return rr.Hdr.String() + strings.ToUpper(rr.Locator) } +func (rr *NIMLOC) len() int { return rr.Hdr.len() + len(rr.Locator)/2 } + +type OPENPGPKEY struct { + Hdr RR_Header + PublicKey string `dns:"base64"` +} + +func (rr *OPENPGPKEY) Header() *RR_Header { return &rr.Hdr } +func (rr *OPENPGPKEY) copy() RR { return &OPENPGPKEY{*rr.Hdr.copyHeader(), rr.PublicKey} } +func (rr *OPENPGPKEY) String() string { return rr.Hdr.String() + rr.PublicKey } +func (rr *OPENPGPKEY) len() int { + return rr.Hdr.len() + base64.StdEncoding.DecodedLen(len(rr.PublicKey)) +} + +// TimeToString translates the RRSIG's incep. and expir. times to the +// string representation used when printing the record. +// It takes serial arithmetic (RFC 1982) into account. 
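+//
+// For example (illustrative; the result is only stable for times within
+// roughly 68 years of the current clock):
+//
+//	TimeToString(1420747127) // "20150108195847"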
+func TimeToString(t uint32) string { + mod := ((int64(t) - time.Now().Unix()) / year68) - 1 + if mod < 0 { + mod = 0 + } + ti := time.Unix(int64(t)-(mod*year68), 0).UTC() + return ti.Format("20060102150405") +} + +// StringToTime translates the RRSIG's incep. and expir. times from +// string values like "20110403154150" to an 32 bit integer. +// It takes serial arithmetic (RFC 1982) into account. +func StringToTime(s string) (uint32, error) { + t, e := time.Parse("20060102150405", s) + if e != nil { + return 0, e + } + mod := (t.Unix() / year68) - 1 + if mod < 0 { + mod = 0 + } + return uint32(t.Unix() - (mod * year68)), nil +} + +// saltToString converts a NSECX salt to uppercase and +// returns "-" when it is empty +func saltToString(s string) string { + if len(s) == 0 { + return "-" + } + return strings.ToUpper(s) +} + +func cmToString(mantissa, exponent uint8) string { + switch exponent { + case 0, 1: + if exponent == 1 { + mantissa *= 10 + } + return fmt.Sprintf("%.02f", float32(mantissa)) + default: + s := fmt.Sprintf("%d", mantissa) + for i := uint8(0); i < exponent-2; i++ { + s += "0" + } + return s + } + panic("dns: not reached") +} + +func euiToString(eui uint64, bits int) (hex string) { + switch bits { + case 64: + hex = fmt.Sprintf("%16.16x", eui) + hex = hex[0:2] + "-" + hex[2:4] + "-" + hex[4:6] + "-" + hex[6:8] + + "-" + hex[8:10] + "-" + hex[10:12] + "-" + hex[12:14] + "-" + hex[14:16] + case 48: + hex = fmt.Sprintf("%12.12x", eui) + hex = hex[0:2] + "-" + hex[2:4] + "-" + hex[4:6] + "-" + hex[6:8] + + "-" + hex[8:10] + "-" + hex[10:12] + } + return +} + +// copyIP returns a copy of ip. +func copyIP(ip net.IP) net.IP { + p := make(net.IP, len(ip)) + copy(p, ip) + return p +} + +// Map of constructors for each RR type. +var typeToRR = map[uint16]func() RR{ + TypeA: func() RR { return new(A) }, + TypeAAAA: func() RR { return new(AAAA) }, + TypeAFSDB: func() RR { return new(AFSDB) }, + // TypeCAA: func() RR { return new(CAA) }, + TypeCDS: func() RR { return new(CDS) }, + TypeCERT: func() RR { return new(CERT) }, + TypeCNAME: func() RR { return new(CNAME) }, + TypeDHCID: func() RR { return new(DHCID) }, + TypeDLV: func() RR { return new(DLV) }, + TypeDNAME: func() RR { return new(DNAME) }, + TypeDNSKEY: func() RR { return new(DNSKEY) }, + TypeDS: func() RR { return new(DS) }, + TypeEUI48: func() RR { return new(EUI48) }, + TypeEUI64: func() RR { return new(EUI64) }, + TypeGID: func() RR { return new(GID) }, + TypeGPOS: func() RR { return new(GPOS) }, + TypeEID: func() RR { return new(EID) }, + TypeHINFO: func() RR { return new(HINFO) }, + TypeHIP: func() RR { return new(HIP) }, + TypeKX: func() RR { return new(KX) }, + TypeL32: func() RR { return new(L32) }, + TypeL64: func() RR { return new(L64) }, + TypeLOC: func() RR { return new(LOC) }, + TypeLP: func() RR { return new(LP) }, + TypeMB: func() RR { return new(MB) }, + TypeMD: func() RR { return new(MD) }, + TypeMF: func() RR { return new(MF) }, + TypeMG: func() RR { return new(MG) }, + TypeMINFO: func() RR { return new(MINFO) }, + TypeMR: func() RR { return new(MR) }, + TypeMX: func() RR { return new(MX) }, + TypeNAPTR: func() RR { return new(NAPTR) }, + TypeNID: func() RR { return new(NID) }, + TypeNINFO: func() RR { return new(NINFO) }, + TypeNIMLOC: func() RR { return new(NIMLOC) }, + TypeNS: func() RR { return new(NS) }, + TypeNSAP: func() RR { return new(NSAP) }, + TypeNSAPPTR: func() RR { return new(NSAPPTR) }, + TypeNSEC3: func() RR { return new(NSEC3) }, + TypeNSEC3PARAM: func() RR { return new(NSEC3PARAM) }, + 
TypeNSEC: func() RR { return new(NSEC) },
+	TypeOPENPGPKEY: func() RR { return new(OPENPGPKEY) },
+	TypeOPT:        func() RR { return new(OPT) },
+	TypePTR:        func() RR { return new(PTR) },
+	TypeRKEY:       func() RR { return new(RKEY) },
+	TypeRP:         func() RR { return new(RP) },
+	TypePX:         func() RR { return new(PX) },
+	TypeRRSIG:      func() RR { return new(RRSIG) },
+	TypeRT:         func() RR { return new(RT) },
+	TypeSOA:        func() RR { return new(SOA) },
+	TypeSPF:        func() RR { return new(SPF) },
+	TypeSRV:        func() RR { return new(SRV) },
+	TypeSSHFP:      func() RR { return new(SSHFP) },
+	TypeTA:         func() RR { return new(TA) },
+	TypeTALINK:     func() RR { return new(TALINK) },
+	TypeTKEY:       func() RR { return new(TKEY) },
+	TypeTLSA:       func() RR { return new(TLSA) },
+	TypeTSIG:       func() RR { return new(TSIG) },
+	TypeTXT:        func() RR { return new(TXT) },
+	TypeUID:        func() RR { return new(UID) },
+	TypeUINFO:      func() RR { return new(UINFO) },
+	TypeURI:        func() RR { return new(URI) },
+	TypeWKS:        func() RR { return new(WKS) },
+	TypeX25:        func() RR { return new(X25) },
+}
diff --git a/Godeps/_workspace/src/github.com/miekg/dns/udp.go b/Godeps/_workspace/src/github.com/miekg/dns/udp.go
new file mode 100644
index 0000000000..0342543bf7
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/miekg/dns/udp.go
@@ -0,0 +1,55 @@
+// +build !windows
+
+package dns
+
+import (
+	"net"
+	"syscall"
+)
+
+type sessionUDP struct {
+	raddr   *net.UDPAddr
+	context []byte
+}
+
+func (s *sessionUDP) RemoteAddr() net.Addr { return s.raddr }
+
+// setUDPSocketOptions sets the UDP socket options.
+// This function is implemented on a per-platform basis. See udp_*.go for more details.
+func setUDPSocketOptions(conn *net.UDPConn) error {
+	sa, err := getUDPSocketName(conn)
+	if err != nil {
+		return err
+	}
+	switch sa.(type) {
+	case *syscall.SockaddrInet6:
+		v6only, err := getUDPSocketOptions6Only(conn)
+		if err != nil {
+			return err
+		}
+		setUDPSocketOptions6(conn)
+		if !v6only {
+			setUDPSocketOptions4(conn)
+		}
+	case *syscall.SockaddrInet4:
+		setUDPSocketOptions4(conn)
+	}
+	return nil
+}
+
+// readFromSessionUDP acts just like net.UDPConn.ReadFrom(), but returns a session object instead of a
+// net.UDPAddr.
+func readFromSessionUDP(conn *net.UDPConn, b []byte) (int, *sessionUDP, error) {
+	oob := make([]byte, 40)
+	n, oobn, _, raddr, err := conn.ReadMsgUDP(b, oob)
+	if err != nil {
+		return n, nil, err
+	}
+	return n, &sessionUDP{raddr, oob[:oobn]}, err
+}
+
+// writeToSessionUDP acts just like net.UDPConn.WriteTo(), but uses a *sessionUDP instead of a net.Addr.
+func writeToSessionUDP(conn *net.UDPConn, b []byte, session *sessionUDP) (int, error) {
+	n, _, err := conn.WriteMsgUDP(b, session.context, session.raddr)
+	return n, err
+}
diff --git a/Godeps/_workspace/src/github.com/miekg/dns/udp_linux.go b/Godeps/_workspace/src/github.com/miekg/dns/udp_linux.go
new file mode 100644
index 0000000000..7a107857e1
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/miekg/dns/udp_linux.go
@@ -0,0 +1,63 @@
+// +build linux

+package dns
+
+// See:
+// * http://stackoverflow.com/questions/3062205/setting-the-source-ip-for-a-udp-socket and
+// * http://blog.powerdns.com/2012/10/08/on-binding-datagram-udp-sockets-to-the-any-addresses/
+//
+// Why do we need this: when listening on 0.0.0.0 with UDP, the kernel decides which interface
+// an outgoing reply leaves on, and this might not always be the correct one. This code makes
+// sure the egress packet's interface matches the ingress one.
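+//
+// Sketch of the resulting flow (the socket options named below are the ones
+// this file sets; the session plumbing lives in udp.go):
+//
+//	setUDPSocketOptions4/6 enable IP_PKTINFO / IPV6_RECVPKTINFO, so
+//	readFromSessionUDP receives the destination address as OOB data and
+//	keeps it in the session; writeToSessionUDP hands the same OOB data
+//	back to WriteMsgUDP, so the reply leaves from the address the query
+//	arrived on.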
+
+import (
+	"net"
+	"syscall"
+)
+
+// setUDPSocketOptions4 prepares the v4 socket for sessions.
+func setUDPSocketOptions4(conn *net.UDPConn) error {
+	file, err := conn.File()
+	if err != nil {
+		return err
+	}
+	if err := syscall.SetsockoptInt(int(file.Fd()), syscall.IPPROTO_IP, syscall.IP_PKTINFO, 1); err != nil {
+		return err
+	}
+	return nil
+}
+
+// setUDPSocketOptions6 prepares the v6 socket for sessions.
+func setUDPSocketOptions6(conn *net.UDPConn) error {
+	file, err := conn.File()
+	if err != nil {
+		return err
+	}
+	if err := syscall.SetsockoptInt(int(file.Fd()), syscall.IPPROTO_IPV6, syscall.IPV6_RECVPKTINFO, 1); err != nil {
+		return err
+	}
+	return nil
+}
+
+// getUDPSocketOptions6Only returns true if the socket is v6 only and false when it is v4/v6 combined
+// (dualstack).
+func getUDPSocketOptions6Only(conn *net.UDPConn) (bool, error) {
+	file, err := conn.File()
+	if err != nil {
+		return false, err
+	}
+	// dual stack. See http://stackoverflow.com/questions/1618240/how-to-support-both-ipv4-and-ipv6-connections
+	v6only, err := syscall.GetsockoptInt(int(file.Fd()), syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY)
+	if err != nil {
+		return false, err
+	}
+	return v6only == 1, nil
+}
+
+func getUDPSocketName(conn *net.UDPConn) (syscall.Sockaddr, error) {
+	file, err := conn.File()
+	if err != nil {
+		return nil, err
+	}
+	return syscall.Getsockname(int(file.Fd()))
+}
diff --git a/Godeps/_workspace/src/github.com/miekg/dns/udp_other.go b/Godeps/_workspace/src/github.com/miekg/dns/udp_other.go
new file mode 100644
index 0000000000..c38dd3e7f0
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/miekg/dns/udp_other.go
@@ -0,0 +1,17 @@
+// +build !linux
+
+package dns
+
+import (
+	"net"
+	"syscall"
+)
+
+// These do nothing. See udp_linux.go for an example of how to implement this.
+
+// We tried to adhere to some kind of naming scheme.
+
+func setUDPSocketOptions4(conn *net.UDPConn) error                 { return nil }
+func setUDPSocketOptions6(conn *net.UDPConn) error                 { return nil }
+func getUDPSocketOptions6Only(conn *net.UDPConn) (bool, error)     { return false, nil }
+func getUDPSocketName(conn *net.UDPConn) (syscall.Sockaddr, error) { return nil, nil }
diff --git a/Godeps/_workspace/src/github.com/miekg/dns/udp_windows.go b/Godeps/_workspace/src/github.com/miekg/dns/udp_windows.go
new file mode 100644
index 0000000000..4c48723b56
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/miekg/dns/udp_windows.go
@@ -0,0 +1,34 @@
+// +build windows
+
+package dns
+
+import "net"
+
+type sessionUDP struct {
+	raddr *net.UDPAddr
+}
+
+// readFromSessionUDP acts just like net.UDPConn.ReadFrom(), but returns a session object instead of a
+// net.UDPAddr.
+func readFromSessionUDP(conn *net.UDPConn, b []byte) (int, *sessionUDP, error) {
+	n, raddr, err := conn.ReadFrom(b)
+	if err != nil {
+		return n, nil, err
+	}
+	session := &sessionUDP{raddr.(*net.UDPAddr)}
+	return n, session, err
+}
+
+// writeToSessionUDP acts just like net.UDPConn.WriteTo(), but uses a *sessionUDP instead of a net.Addr.
+func writeToSessionUDP(conn *net.UDPConn, b []byte, session *sessionUDP) (int, error) {
+	n, err := conn.WriteTo(b, session.raddr)
+	return n, err
+}
+
+func (s *sessionUDP) RemoteAddr() net.Addr { return s.raddr }
+
+// setUDPSocketOptions sets the UDP socket options.
+// This function is implemented on a per platform basis. See udp_*.go for more details.
+func setUDPSocketOptions(conn *net.UDPConn) error {
+	return nil
+}
diff --git a/Godeps/_workspace/src/github.com/miekg/dns/update.go b/Godeps/_workspace/src/github.com/miekg/dns/update.go
new file mode 100644
index 0000000000..053e42245f
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/miekg/dns/update.go
@@ -0,0 +1,132 @@
+// DYNAMIC UPDATES
+//
+// Dynamic updates reuse the DNS message format, but rename three of
+// the sections. Question is Zone, Answer is Prerequisite, Authority is
+// Update; only the Additional section is not renamed. See RFC 2136 for the gory details.
+//
+// You can set a rather complex set of rules for the existence or absence of
+// certain resource records or names in a zone to specify whether resource records
+// should be added or removed. The table from RFC 2136, supplemented with the Go
+// DNS functions, shows which functions exist to specify the prerequisites.
+//
+// 3.2.4 - Table Of Metavalues Used In Prerequisite Section
+//
+// CLASS    TYPE     RDATA    Meaning                     Function
+// --------------------------------------------------------------
+// ANY      ANY      empty    Name is in use              dns.NameUsed
+// ANY      rrset    empty    RRset exists (value indep)  dns.RRsetUsed
+// NONE     ANY      empty    Name is not in use          dns.NameNotUsed
+// NONE     rrset    empty    RRset does not exist        dns.RRsetNotUsed
+// zone     rrset    rr       RRset exists (value dep)    dns.Used
+//
+// The prerequisite section can also be left empty.
+// Once you have decided on the prerequisites you can tell what RRs should
+// be added or deleted. The next table shows the options you have and
+// what functions to call.
+//
+// 3.4.2.6 - Table Of Metavalues Used In Update Section
+//
+// CLASS    TYPE     RDATA    Meaning                     Function
+// ---------------------------------------------------------------
+// ANY      ANY      empty    Delete all RRsets from name dns.RemoveName
+// ANY      rrset    empty    Delete an RRset             dns.RemoveRRset
+// NONE     rrset    rr       Delete an RR from RRset     dns.Remove
+// zone     rrset    rr       Add to an RRset             dns.Insert
+//
+package dns
+
+// NameUsed sets the RRs in the prereq section to
+// "Name is in use" RRs. RFC 2136 section 2.4.4.
+func (u *Msg) NameUsed(rr []RR) {
+	u.Answer = make([]RR, len(rr))
+	for i, r := range rr {
+		u.Answer[i] = &ANY{Hdr: RR_Header{Name: r.Header().Name, Ttl: 0, Rrtype: TypeANY, Class: ClassANY}}
+	}
+}
+
+// NameNotUsed sets the RRs in the prereq section to
+// "Name is not in use" RRs. RFC 2136 section 2.4.5.
+func (u *Msg) NameNotUsed(rr []RR) {
+	u.Answer = make([]RR, len(rr))
+	for i, r := range rr {
+		u.Answer[i] = &ANY{Hdr: RR_Header{Name: r.Header().Name, Ttl: 0, Rrtype: TypeANY, Class: ClassNONE}}
+	}
+}
+
+// Used sets the RRs in the prereq section to
+// "RRset exists (value dependent -- with rdata)" RRs. RFC 2136 section 2.4.2.
+func (u *Msg) Used(rr []RR) {
+	if len(u.Question) == 0 {
+		panic("dns: empty question section")
+	}
+	u.Answer = make([]RR, len(rr))
+	for i, r := range rr {
+		u.Answer[i] = r
+		u.Answer[i].Header().Class = u.Question[0].Qclass
+	}
+}
+
+// RRsetUsed sets the RRs in the prereq section to
+// "RRset exists (value independent -- no rdata)" RRs. RFC 2136 section 2.4.1.
+func (u *Msg) RRsetUsed(rr []RR) {
+	u.Answer = make([]RR, len(rr))
+	for i, r := range rr {
+		u.Answer[i] = r
+		u.Answer[i].Header().Class = ClassANY
+		u.Answer[i].Header().Ttl = 0
+		u.Answer[i].Header().Rdlength = 0
+	}
+}
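+
+// A short sketch of composing an update message with the helpers in this
+// file (the zone name and record are made-up values; error handling elided):
+//
+//	m := new(dns.Msg)
+//	m.SetUpdate("example.org.")
+//	rr, _ := dns.NewRR("host.example.org. 300 IN A 10.0.0.1")
+//	m.NameNotUsed([]dns.RR{rr}) // prerequisite: the name must not yet exist
+//	m.Insert([]dns.RR{rr})      // update: add the RRset
+//	// then send m to the zone's primary server, e.g. with Client.Exchange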
+
+// RRsetNotUsed sets the RRs in the prereq section to
+// "RRset does not exist" RRs. RFC 2136 section 2.4.3.
+func (u *Msg) RRsetNotUsed(rr []RR) {
+	u.Answer = make([]RR, len(rr))
+	for i, r := range rr {
+		u.Answer[i] = r
+		u.Answer[i].Header().Class = ClassNONE
+		u.Answer[i].Header().Rdlength = 0
+		u.Answer[i].Header().Ttl = 0
+	}
+}
+
+// Insert creates a dynamic update packet that adds a complete RRset, see RFC 2136 section 2.5.1.
+func (u *Msg) Insert(rr []RR) {
+	if len(u.Question) == 0 {
+		panic("dns: empty question section")
+	}
+	u.Ns = make([]RR, len(rr))
+	for i, r := range rr {
+		u.Ns[i] = r
+		u.Ns[i].Header().Class = u.Question[0].Qclass
+	}
+}
+
+// RemoveRRset creates a dynamic update packet that deletes an RRset, see RFC 2136 section 2.5.2.
+func (u *Msg) RemoveRRset(rr []RR) {
+	u.Ns = make([]RR, len(rr))
+	for i, r := range rr {
+		u.Ns[i] = r
+		u.Ns[i].Header().Class = ClassANY
+		u.Ns[i].Header().Rdlength = 0
+		u.Ns[i].Header().Ttl = 0
+	}
+}
+
+// RemoveName creates a dynamic update packet that deletes all RRsets of a name, see RFC 2136 section 2.5.3.
+func (u *Msg) RemoveName(rr []RR) {
+	u.Ns = make([]RR, len(rr))
+	for i, r := range rr {
+		u.Ns[i] = &ANY{Hdr: RR_Header{Name: r.Header().Name, Ttl: 0, Rrtype: TypeANY, Class: ClassANY}}
+	}
+}
+
+// Remove creates a dynamic update packet that deletes an RR from an RRset, see RFC 2136 section 2.5.4.
+func (u *Msg) Remove(rr []RR) {
+	u.Ns = make([]RR, len(rr))
+	for i, r := range rr {
+		u.Ns[i] = r
+		u.Ns[i].Header().Class = ClassNONE
+		u.Ns[i].Header().Ttl = 0
+	}
+}
diff --git a/Godeps/_workspace/src/github.com/miekg/dns/xfr.go b/Godeps/_workspace/src/github.com/miekg/dns/xfr.go
new file mode 100644
index 0000000000..57bfb1676f
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/miekg/dns/xfr.go
@@ -0,0 +1,236 @@
+package dns
+
+import (
+	"time"
+)
+
+// Envelope is used when doing a zone transfer with a remote server.
+type Envelope struct {
+	RR    []RR  // The set of RRs in the answer section of the xfr reply message.
+	Error error // If something went wrong, this contains the error.
+}
+
+// A Transfer defines parameters that are used during a zone transfer.
+type Transfer struct {
+	*Conn
+	DialTimeout    time.Duration     // net.DialTimeout (ns), defaults to 2 * 1e9
+	ReadTimeout    time.Duration     // net.Conn.SetReadTimeout value for connections (ns), defaults to 2 * 1e9
+	WriteTimeout   time.Duration     // net.Conn.SetWriteTimeout value for connections (ns), defaults to 2 * 1e9
+	TsigSecret     map[string]string // Secret(s) for Tsig map[], zonename must be fully qualified
+	tsigTimersOnly bool
+}
+
+// Think we need a way to stop the transfer.
+
+// In performs an incoming transfer with the server at address a.
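+// Basic use pattern for an incoming AXFR (a sketch; the zone name and
+// server address are made-up values, error handling elided):
+//
+//	tr := new(dns.Transfer)
+//	m := new(dns.Msg)
+//	m.SetAxfr("example.org.")
+//	ch, _ := tr.In(m, "ns1.example.org:53")
+//	for env := range ch {
+//		// env.Error reports a broken transfer,
+//		// env.RR holds the next batch of records.
+//	}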
+func (t *Transfer) In(q *Msg, a string) (env chan *Envelope, err error) {
+	timeout := dnsTimeout
+	if t.DialTimeout != 0 {
+		timeout = t.DialTimeout
+	}
+	t.Conn, err = DialTimeout("tcp", a, timeout)
+	if err != nil {
+		return nil, err
+	}
+	if err := t.WriteMsg(q); err != nil {
+		return nil, err
+	}
+	env = make(chan *Envelope)
+	go func() {
+		if q.Question[0].Qtype == TypeAXFR {
+			go t.inAxfr(q.Id, env)
+			return
+		}
+		if q.Question[0].Qtype == TypeIXFR {
+			go t.inIxfr(q.Id, env)
+			return
+		}
+	}()
+	return env, nil
+}
+
+func (t *Transfer) inAxfr(id uint16, c chan *Envelope) {
+	first := true
+	defer t.Close()
+	defer close(c)
+	timeout := dnsTimeout
+	if t.ReadTimeout != 0 {
+		timeout = t.ReadTimeout
+	}
+	for {
+		t.Conn.SetReadDeadline(time.Now().Add(timeout))
+		in, err := t.ReadMsg()
+		if err != nil {
+			c <- &Envelope{nil, err}
+			return
+		}
+		if id != in.Id {
+			c <- &Envelope{in.Answer, ErrId}
+			return
+		}
+		if first {
+			if !isSOAFirst(in) {
+				c <- &Envelope{in.Answer, ErrSoa}
+				return
+			}
+			first = !first
+			// only one answer that is SOA, receive more
+			if len(in.Answer) == 1 {
+				t.tsigTimersOnly = true
+				c <- &Envelope{in.Answer, nil}
+				continue
+			}
+		}
+
+		if !first {
+			t.tsigTimersOnly = true // Subsequent envelopes use this.
+			if isSOALast(in) {
+				c <- &Envelope{in.Answer, nil}
+				return
+			}
+			c <- &Envelope{in.Answer, nil}
+		}
+	}
+	panic("dns: not reached")
+}
+
+func (t *Transfer) inIxfr(id uint16, c chan *Envelope) {
+	serial := uint32(0) // The first serial seen is the current server serial
+	first := true
+	defer t.Close()
+	defer close(c)
+	timeout := dnsTimeout
+	if t.ReadTimeout != 0 {
+		timeout = t.ReadTimeout
+	}
+	for {
+		t.SetReadDeadline(time.Now().Add(timeout))
+		in, err := t.ReadMsg()
+		if err != nil {
+			c <- &Envelope{nil, err} // in may be nil on a read error
+			return
+		}
+		if id != in.Id {
+			c <- &Envelope{in.Answer, ErrId}
+			return
+		}
+		if first {
+			// A single SOA RR signals "no changes"
+			if len(in.Answer) == 1 && isSOAFirst(in) {
+				c <- &Envelope{in.Answer, nil}
+				return
+			}
+
+			// Check if the returned answer is ok
+			if !isSOAFirst(in) {
+				c <- &Envelope{in.Answer, ErrSoa}
+				return
+			}
+			// This serial is important
+			serial = in.Answer[0].(*SOA).Serial
+			first = !first
+		}
+
+		// Now we need to check each message for SOA records, to see what we need to do
+		if !first {
+			t.tsigTimersOnly = true
+			// If the last record in the IXFR contains the server's SOA, we should quit
+			if v, ok := in.Answer[len(in.Answer)-1].(*SOA); ok {
+				if v.Serial == serial {
+					c <- &Envelope{in.Answer, nil}
+					return
+				}
+			}
+			c <- &Envelope{in.Answer, nil}
+		}
+	}
+}
+
+// Out performs an outgoing transfer with the client connecting in w.
+// Basic use pattern:
+//
+//	ch := make(chan *dns.Envelope)
+//	tr := new(dns.Transfer)
+//	tr.Out(w, r, ch)
+//	ch <- &dns.Envelope{RR: []dns.RR{soa, rr1, rr2, rr3, soa}}
+//	close(ch)
+//	w.Hijack()
+//	// w.Close() // Client closes connection
+//
+// The server is responsible for sending the correct sequence of RRs through the
+// channel ch.
+func (t *Transfer) Out(w ResponseWriter, q *Msg, ch chan *Envelope) error {
+	r := new(Msg)
+	// Compress?
+	r.SetReply(q)
+	r.Authoritative = true
+
+	go func() {
+		for x := range ch {
+			// assume it fits TODO(miek): fix
+			r.Answer = append(r.Answer, x.RR...)
+			if err := w.WriteMsg(r); err != nil {
+				return
+			}
+		}
+		w.TsigTimersOnly(true)
+		r.Answer = nil
+	}()
+	return nil
+}
+
+// ReadMsg reads a message from the transfer connection t.
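+// If TsigSecret is set, the TSIG on the message is verified as well. A
+// sketch of wiring that up (key name and base64 secret are made-up values):
+//
+//	tr := &dns.Transfer{TsigSecret: map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="}}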
+func (t *Transfer) ReadMsg() (*Msg, error) { + m := new(Msg) + p := make([]byte, MaxMsgSize) + n, err := t.Read(p) + if err != nil && n == 0 { + return nil, err + } + p = p[:n] + if err := m.Unpack(p); err != nil { + return nil, err + } + if ts := m.IsTsig(); ts != nil && t.TsigSecret != nil { + if _, ok := t.TsigSecret[ts.Hdr.Name]; !ok { + return m, ErrSecret + } + // Need to work on the original message p, as that was used to calculate the tsig. + err = TsigVerify(p, t.TsigSecret[ts.Hdr.Name], t.tsigRequestMAC, t.tsigTimersOnly) + } + return m, err +} + +// WriteMsg writes a message through the transfer connection t. +func (t *Transfer) WriteMsg(m *Msg) (err error) { + var out []byte + if ts := m.IsTsig(); ts != nil && t.TsigSecret != nil { + if _, ok := t.TsigSecret[ts.Hdr.Name]; !ok { + return ErrSecret + } + out, t.tsigRequestMAC, err = TsigGenerate(m, t.TsigSecret[ts.Hdr.Name], t.tsigRequestMAC, t.tsigTimersOnly) + } else { + out, err = m.Pack() + } + if err != nil { + return err + } + if _, err = t.Write(out); err != nil { + return err + } + return nil +} + +func isSOAFirst(in *Msg) bool { + if len(in.Answer) > 0 { + return in.Answer[0].Header().Rrtype == TypeSOA + } + return false +} + +func isSOALast(in *Msg) bool { + if len(in.Answer) > 0 { + return in.Answer[len(in.Answer)-1].Header().Rrtype == TypeSOA + } + return false +} diff --git a/Godeps/_workspace/src/github.com/miekg/dns/zgenerate.go b/Godeps/_workspace/src/github.com/miekg/dns/zgenerate.go new file mode 100644 index 0000000000..7f1183e874 --- /dev/null +++ b/Godeps/_workspace/src/github.com/miekg/dns/zgenerate.go @@ -0,0 +1,157 @@ +package dns + +import ( + "fmt" + "strconv" + "strings" +) + +// Parse the $GENERATE statement as used in BIND9 zones. +// See http://www.zytrax.com/books/dns/ch8/generate.html for instance. +// We are called after '$GENERATE '. After which we expect: +// * the range (12-24/2) +// * lhs (ownername) +// * [[ttl][class]] +// * type +// * rhs (rdata) +// But we are lazy here, only the range is parsed *all* occurences +// of $ after that are interpreted. +// Any error are returned as a string value, the empty string signals +// "no error". +func generate(l lex, c chan lex, t chan *Token, o string) string { + step := 1 + if i := strings.IndexAny(l.token, "/"); i != -1 { + if i+1 == len(l.token) { + return "bad step in $GENERATE range" + } + if s, e := strconv.Atoi(l.token[i+1:]); e != nil { + return "bad step in $GENERATE range" + } else { + if s < 0 { + return "bad step in $GENERATE range" + } + step = s + } + l.token = l.token[:i] + } + sx := strings.SplitN(l.token, "-", 2) + if len(sx) != 2 { + return "bad start-stop in $GENERATE range" + } + start, err := strconv.Atoi(sx[0]) + if err != nil { + return "bad start in $GENERATE range" + } + end, err := strconv.Atoi(sx[1]) + if err != nil { + return "bad stop in $GENERATE range" + } + if end < 0 || start < 0 || end <= start { + return "bad range in $GENERATE range" + } + + <-c // _BLANK + // Create a complete new string, which we then parse again. 
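+	// As an illustration (not an input from this file), a line such as
+	//
+	//	$GENERATE 1-3 host-$ A 10.0.0.$
+	//
+	// expands to host-1 ... host-3 with addresses 10.0.0.1 ... 10.0.0.3;
+	// a ${offset,width,base} modifier offsets and zero-pads the substituted value.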
+ s := "" +BuildRR: + l = <-c + if l.value != _NEWLINE && l.value != _EOF { + s += l.token + goto BuildRR + } + for i := start; i <= end; i += step { + var ( + escape bool + dom string + mod string + err string + offset int + ) + + for j := 0; j < len(s); j++ { // No 'range' because we need to jump around + switch s[j] { + case '\\': + if escape { + dom += "\\" + escape = false + continue + } + escape = true + case '$': + mod = "%d" + offset = 0 + if escape { + dom += "$" + escape = false + continue + } + escape = false + if j+1 >= len(s) { // End of the string + dom += fmt.Sprintf(mod, i+offset) + continue + } else { + if s[j+1] == '$' { + dom += "$" + j++ + continue + } + } + // Search for { and } + if s[j+1] == '{' { // Modifier block + sep := strings.Index(s[j+2:], "}") + if sep == -1 { + return "bad modifier in $GENERATE" + } + mod, offset, err = modToPrintf(s[j+2 : j+2+sep]) + if err != "" { + return err + } + j += 2 + sep // Jump to it + } + dom += fmt.Sprintf(mod, i+offset) + default: + if escape { // Pretty useless here + escape = false + continue + } + dom += string(s[j]) + } + } + // Re-parse the RR and send it on the current channel t + rx, e := NewRR("$ORIGIN " + o + "\n" + dom) + if e != nil { + return e.(*ParseError).err + } + t <- &Token{RR: rx} + // Its more efficient to first built the rrlist and then parse it in + // one go! But is this a problem? + } + return "" +} + +// Convert a $GENERATE modifier 0,0,d to something Printf can deal with. +func modToPrintf(s string) (string, int, string) { + xs := strings.SplitN(s, ",", 3) + if len(xs) != 3 { + return "", 0, "bad modifier in $GENERATE" + } + // xs[0] is offset, xs[1] is width, xs[2] is base + if xs[2] != "o" && xs[2] != "d" && xs[2] != "x" && xs[2] != "X" { + return "", 0, "bad base in $GENERATE" + } + offset, err := strconv.Atoi(xs[0]) + if err != nil { + return "", 0, "bad offset in $GENERATE" + } + width, err := strconv.Atoi(xs[1]) + if err != nil { + return "", offset, "bad width in $GENERATE" + } + switch { + case width < 0: + return "", offset, "bad width in $GENERATE" + case width == 0: + return "%" + xs[1] + xs[2], offset, "" + } + return "%0" + xs[1] + xs[2], offset, "" +} diff --git a/Godeps/_workspace/src/github.com/miekg/dns/zscan.go b/Godeps/_workspace/src/github.com/miekg/dns/zscan.go new file mode 100644 index 0000000000..41a328e42d --- /dev/null +++ b/Godeps/_workspace/src/github.com/miekg/dns/zscan.go @@ -0,0 +1,956 @@ +package dns + +import ( + "io" + "log" + "os" + "strconv" + "strings" +) + +type debugging bool + +const debug debugging = false + +func (d debugging) Printf(format string, args ...interface{}) { + if d { + log.Printf(format, args...) + } +} + +const maxTok = 2048 // Largest token we can return. +const maxUint16 = 1<<16 - 1 + +// Tokinize a RFC 1035 zone file. The tokenizer will normalize it: +// * Add ownernames if they are left blank; +// * Suppress sequences of spaces; +// * Make each RR fit on one line (_NEWLINE is send as last) +// * Handle comments: ; +// * Handle braces - anywhere. 
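+//
+// For instance, a record whose rdata is spread over several physical lines
+// with braces (an illustrative zone snippet):
+//
+//	example.org. 3600 IN SOA ns.example.org. admin.example.org. (
+//	             2015010801 ; serial
+//	             7200 3600 1209600 3600 )
+//
+// reaches the parser as one logical line with the comment stripped.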
+const ( + // Zonefile + _EOF = iota + _STRING + _BLANK + _QUOTE + _NEWLINE + _RRTYPE + _OWNER + _CLASS + _DIRORIGIN // $ORIGIN + _DIRTTL // $TTL + _DIRINCLUDE // $INCLUDE + _DIRGENERATE // $GENERATE + + // Privatekey file + _VALUE + _KEY + + _EXPECT_OWNER_DIR // Ownername + _EXPECT_OWNER_BL // Whitespace after the ownername + _EXPECT_ANY // Expect rrtype, ttl or class + _EXPECT_ANY_NOCLASS // Expect rrtype or ttl + _EXPECT_ANY_NOCLASS_BL // The whitespace after _EXPECT_ANY_NOCLASS + _EXPECT_ANY_NOTTL // Expect rrtype or class + _EXPECT_ANY_NOTTL_BL // Whitespace after _EXPECT_ANY_NOTTL + _EXPECT_RRTYPE // Expect rrtype + _EXPECT_RRTYPE_BL // Whitespace BEFORE rrtype + _EXPECT_RDATA // The first element of the rdata + _EXPECT_DIRTTL_BL // Space after directive $TTL + _EXPECT_DIRTTL // Directive $TTL + _EXPECT_DIRORIGIN_BL // Space after directive $ORIGIN + _EXPECT_DIRORIGIN // Directive $ORIGIN + _EXPECT_DIRINCLUDE_BL // Space after directive $INCLUDE + _EXPECT_DIRINCLUDE // Directive $INCLUDE + _EXPECT_DIRGENERATE // Directive $GENERATE + _EXPECT_DIRGENERATE_BL // Space after directive $GENERATE +) + +// ParseError is a parsing error. It contains the parse error and the location in the io.Reader +// where the error occured. +type ParseError struct { + file string + err string + lex lex +} + +func (e *ParseError) Error() (s string) { + if e.file != "" { + s = e.file + ": " + } + s += "dns: " + e.err + ": " + strconv.QuoteToASCII(e.lex.token) + " at line: " + + strconv.Itoa(e.lex.line) + ":" + strconv.Itoa(e.lex.column) + return +} + +type lex struct { + token string // text of the token + tokenUpper string // uppercase text of the token + length int // lenght of the token + err bool // when true, token text has lexer error + value uint8 // value: _STRING, _BLANK, etc. + line int // line in the file + column int // column in the file + torc uint16 // type or class as parsed in the lexer, we only need to look this up in the grammar + comment string // any comment text seen +} + +// *Tokens are returned when a zone file is parsed. +type Token struct { + RR // the scanned resource record when error is not nil + Error *ParseError // when an error occured, this has the error specifics + Comment string // a potential comment positioned after the RR and on the same line +} + +// NewRR reads the RR contained in the string s. Only the first RR is returned. +// The class defaults to IN and TTL defaults to 3600. The full zone file +// syntax like $TTL, $ORIGIN, etc. is supported. +// All fields of the returned RR are set, except RR.Header().Rdlength which is set to 0. +func NewRR(s string) (RR, error) { + if s[len(s)-1] != '\n' { // We need a closing newline + return ReadRR(strings.NewReader(s+"\n"), "") + } + return ReadRR(strings.NewReader(s), "") +} + +// ReadRR reads the RR contained in q. +// See NewRR for more documentation. +func ReadRR(q io.Reader, filename string) (RR, error) { + r := <-parseZoneHelper(q, ".", filename, 1) + if r.Error != nil { + return nil, r.Error + } + return r.RR, nil +} + +// ParseZone reads a RFC 1035 style one from r. It returns *Tokens on the +// returned channel, which consist out the parsed RR, a potential comment or an error. +// If there is an error the RR is nil. The string file is only used +// in error reporting. The string origin is used as the initial origin, as +// if the file would start with: $ORIGIN origin . +// The directives $INCLUDE, $ORIGIN, $TTL and $GENERATE are supported. +// The channel t is closed by ParseZone when the end of r is reached. 
+//
+// Basic usage pattern when reading from a string (z) containing the
+// zone data:
+//
+//	for x := range dns.ParseZone(strings.NewReader(z), "", "") {
+//		if x.Error == nil {
+//			// Do something with x.RR
+//		}
+//	}
+//
+// Comments specified after an RR (and on the same line!) are returned too:
+//
+//	foo. IN A 10.0.0.1 ; this is a comment
+//
+// The text "; this is a comment" is returned in Token.Comment. Comments inside the
+// RR are discarded. Comments on a line by themselves are discarded too.
+func ParseZone(r io.Reader, origin, file string) chan *Token {
+	return parseZoneHelper(r, origin, file, 10000)
+}
+
+func parseZoneHelper(r io.Reader, origin, file string, chansize int) chan *Token {
+	t := make(chan *Token, chansize)
+	go parseZone(r, origin, file, t, 0)
+	return t
+}
+
+func parseZone(r io.Reader, origin, f string, t chan *Token, include int) {
+	defer func() {
+		if include == 0 {
+			close(t)
+		}
+	}()
+	s := scanInit(r)
+	c := make(chan lex, 1000)
+	// Start the lexer
+	go zlexer(s, c)
+	// 6 possible beginnings of a line, _ is a space
+	// 0. _RRTYPE                             -> all omitted until the rrtype
+	// 1. _OWNER _ _RRTYPE                    -> class/ttl omitted
+	// 2. _OWNER _ _STRING _ _RRTYPE          -> class omitted
+	// 3. _OWNER _ _STRING _ _CLASS _ _RRTYPE -> ttl/class
+	// 4. _OWNER _ _CLASS _ _RRTYPE           -> ttl omitted
+	// 5. _OWNER _ _CLASS _ _STRING _ _RRTYPE -> class/ttl (reversed)
+	// After detecting these, we know the _RRTYPE so we can jump to functions
+	// handling the rdata for each of these types.
+
+	if origin == "" {
+		origin = "."
+	}
+	origin = Fqdn(origin)
+	if _, ok := IsDomainName(origin); !ok {
+		t <- &Token{Error: &ParseError{f, "bad initial origin name", lex{}}}
+		return
+	}
+
+	st := _EXPECT_OWNER_DIR // initial state
+	var h RR_Header
+	var defttl uint32 = defaultTtl
+	var prevName string
+	for l := range c {
+		// Lexer spotted an error already
+		if l.err {
+			t <- &Token{Error: &ParseError{f, l.token, l}}
+			return
+		}
+		switch st {
+		case _EXPECT_OWNER_DIR:
+			// We can also expect a directive, like $TTL or $ORIGIN
+			h.Ttl = defttl
+			h.Class = ClassINET
+			switch l.value {
+			case _NEWLINE: // Empty line
+				st = _EXPECT_OWNER_DIR
+			case _OWNER:
+				h.Name = l.token
+				if l.token[0] == '@' {
+					h.Name = origin
+					prevName = h.Name
+					st = _EXPECT_OWNER_BL
+					break
+				}
+				if h.Name[l.length-1] != '.'
{
+					h.Name = appendOrigin(h.Name, origin)
+				}
+				_, ok := IsDomainName(l.token)
+				if !ok {
+					t <- &Token{Error: &ParseError{f, "bad owner name", l}}
+					return
+				}
+				prevName = h.Name
+				st = _EXPECT_OWNER_BL
+			case _DIRTTL:
+				st = _EXPECT_DIRTTL_BL
+			case _DIRORIGIN:
+				st = _EXPECT_DIRORIGIN_BL
+			case _DIRINCLUDE:
+				st = _EXPECT_DIRINCLUDE_BL
+			case _DIRGENERATE:
+				st = _EXPECT_DIRGENERATE_BL
+			case _RRTYPE: // Everything has been omitted, this is the first thing on the line
+				h.Name = prevName
+				h.Rrtype = l.torc
+				st = _EXPECT_RDATA
+			case _CLASS: // First thing on the line is the class
+				h.Name = prevName
+				h.Class = l.torc
+				st = _EXPECT_ANY_NOCLASS_BL
+			case _BLANK:
+				// Discard, can happen when there is nothing on the
+				// line except the RR type
+			case _STRING: // First thing on the line is the TTL
+				if ttl, ok := stringToTtl(l.token); !ok {
+					t <- &Token{Error: &ParseError{f, "not a TTL", l}}
+					return
+				} else {
+					h.Ttl = ttl
+					// Don't set the defttl from this; the $TTL value should be kept
+					// defttl = ttl
+				}
+				st = _EXPECT_ANY_NOTTL_BL
+
+			default:
+				t <- &Token{Error: &ParseError{f, "syntax error at beginning", l}}
+				return
+			}
+		case _EXPECT_DIRINCLUDE_BL:
+			if l.value != _BLANK {
+				t <- &Token{Error: &ParseError{f, "no blank after $INCLUDE-directive", l}}
+				return
+			}
+			st = _EXPECT_DIRINCLUDE
+		case _EXPECT_DIRINCLUDE:
+			if l.value != _STRING {
+				t <- &Token{Error: &ParseError{f, "expecting $INCLUDE value, not this...", l}}
+				return
+			}
+			neworigin := origin // There may optionally be a new origin set after the filename; if not, use the current one
+			l := <-c
+			switch l.value {
+			case _BLANK:
+				l := <-c
+				if l.value == _STRING {
+					if _, ok := IsDomainName(l.token); !ok {
+						t <- &Token{Error: &ParseError{f, "bad origin name", l}}
+						return
+					}
+					// a new origin is specified.
+					if l.token[l.length-1] != '.' {
+						if origin != "." { // Prevent .. endings
+							neworigin = l.token + "."
+ origin + } else { + neworigin = l.token + origin + } + } else { + neworigin = l.token + } + } + case _NEWLINE, _EOF: + // Ok + default: + t <- &Token{Error: &ParseError{f, "garbage after $INCLUDE", l}} + return + } + // Start with the new file + r1, e1 := os.Open(l.token) + if e1 != nil { + t <- &Token{Error: &ParseError{f, "failed to open `" + l.token + "'", l}} + return + } + if include+1 > 7 { + t <- &Token{Error: &ParseError{f, "too deeply nested $INCLUDE", l}} + return + } + parseZone(r1, l.token, neworigin, t, include+1) + st = _EXPECT_OWNER_DIR + case _EXPECT_DIRTTL_BL: + if l.value != _BLANK { + t <- &Token{Error: &ParseError{f, "no blank after $TTL-directive", l}} + return + } + st = _EXPECT_DIRTTL + case _EXPECT_DIRTTL: + if l.value != _STRING { + t <- &Token{Error: &ParseError{f, "expecting $TTL value, not this...", l}} + return + } + if e, _ := slurpRemainder(c, f); e != nil { + t <- &Token{Error: e} + return + } + if ttl, ok := stringToTtl(l.token); !ok { + t <- &Token{Error: &ParseError{f, "expecting $TTL value, not this...", l}} + return + } else { + defttl = ttl + } + st = _EXPECT_OWNER_DIR + case _EXPECT_DIRORIGIN_BL: + if l.value != _BLANK { + t <- &Token{Error: &ParseError{f, "no blank after $ORIGIN-directive", l}} + return + } + st = _EXPECT_DIRORIGIN + case _EXPECT_DIRORIGIN: + if l.value != _STRING { + t <- &Token{Error: &ParseError{f, "expecting $ORIGIN value, not this...", l}} + return + } + if e, _ := slurpRemainder(c, f); e != nil { + t <- &Token{Error: e} + } + if _, ok := IsDomainName(l.token); !ok { + t <- &Token{Error: &ParseError{f, "bad origin name", l}} + return + } + if l.token[l.length-1] != '.' { + if origin != "." { // Prevent .. endings + origin = l.token + "." + origin + } else { + origin = l.token + origin + } + } else { + origin = l.token + } + st = _EXPECT_OWNER_DIR + case _EXPECT_DIRGENERATE_BL: + if l.value != _BLANK { + t <- &Token{Error: &ParseError{f, "no blank after $GENERATE-directive", l}} + return + } + st = _EXPECT_DIRGENERATE + case _EXPECT_DIRGENERATE: + if l.value != _STRING { + t <- &Token{Error: &ParseError{f, "expecting $GENERATE value, not this...", l}} + return + } + if e := generate(l, c, t, origin); e != "" { + t <- &Token{Error: &ParseError{f, e, l}} + return + } + st = _EXPECT_OWNER_DIR + case _EXPECT_OWNER_BL: + if l.value != _BLANK { + t <- &Token{Error: &ParseError{f, "no blank after owner", l}} + return + } + st = _EXPECT_ANY + case _EXPECT_ANY: + switch l.value { + case _RRTYPE: + h.Rrtype = l.torc + st = _EXPECT_RDATA + case _CLASS: + h.Class = l.torc + st = _EXPECT_ANY_NOCLASS_BL + case _STRING: // TTL is this case + if ttl, ok := stringToTtl(l.token); !ok { + t <- &Token{Error: &ParseError{f, "not a TTL", l}} + return + } else { + h.Ttl = ttl + // defttl = ttl // don't set the defttl here + } + st = _EXPECT_ANY_NOTTL_BL + default: + t <- &Token{Error: &ParseError{f, "expecting RR type, TTL or class, not this...", l}} + return + } + case _EXPECT_ANY_NOCLASS_BL: + if l.value != _BLANK { + t <- &Token{Error: &ParseError{f, "no blank before class", l}} + return + } + st = _EXPECT_ANY_NOCLASS + case _EXPECT_ANY_NOTTL_BL: + if l.value != _BLANK { + t <- &Token{Error: &ParseError{f, "no blank before TTL", l}} + return + } + st = _EXPECT_ANY_NOTTL + case _EXPECT_ANY_NOTTL: + switch l.value { + case _CLASS: + h.Class = l.torc + st = _EXPECT_RRTYPE_BL + case _RRTYPE: + h.Rrtype = l.torc + st = _EXPECT_RDATA + default: + t <- &Token{Error: &ParseError{f, "expecting RR type or class, not this...", l}} + return + } + case 
_EXPECT_ANY_NOCLASS: + switch l.value { + case _STRING: // TTL + if ttl, ok := stringToTtl(l.token); !ok { + t <- &Token{Error: &ParseError{f, "not a TTL", l}} + return + } else { + h.Ttl = ttl + // defttl = ttl // don't set the def ttl anymore + } + st = _EXPECT_RRTYPE_BL + case _RRTYPE: + h.Rrtype = l.torc + st = _EXPECT_RDATA + default: + t <- &Token{Error: &ParseError{f, "expecting RR type or TTL, not this...", l}} + return + } + case _EXPECT_RRTYPE_BL: + if l.value != _BLANK { + t <- &Token{Error: &ParseError{f, "no blank before RR type", l}} + return + } + st = _EXPECT_RRTYPE + case _EXPECT_RRTYPE: + if l.value != _RRTYPE { + t <- &Token{Error: &ParseError{f, "unknown RR type", l}} + return + } + h.Rrtype = l.torc + st = _EXPECT_RDATA + case _EXPECT_RDATA: + r, e, c1 := setRR(h, c, origin, f) + if e != nil { + // If e.lex is nil than we have encounter a unknown RR type + // in that case we substitute our current lex token + if e.lex.token == "" && e.lex.value == 0 { + e.lex = l // Uh, dirty + } + t <- &Token{Error: e} + return + } + t <- &Token{RR: r, Comment: c1} + st = _EXPECT_OWNER_DIR + } + } + // If we get here, we and the h.Rrtype is still zero, we haven't parsed anything, this + // is not an error, because an empty zone file is still a zone file. +} + +// zlexer scans the sourcefile and returns tokens on the channel c. +func zlexer(s *scan, c chan lex) { + var l lex + str := make([]byte, maxTok) // Should be enough for any token + stri := 0 // Offset in str (0 means empty) + com := make([]byte, maxTok) // Hold comment text + comi := 0 + quote := false + escape := false + space := false + commt := false + rrtype := false + owner := true + brace := 0 + x, err := s.tokenText() + defer close(c) + for err == nil { + l.column = s.position.Column + l.line = s.position.Line + if stri > maxTok { + l.token = "token length insufficient for parsing" + l.err = true + debug.Printf("[%+v]", l.token) + c <- l + return + } + if comi > maxTok { + l.token = "comment length insufficient for parsing" + l.err = true + debug.Printf("[%+v]", l.token) + c <- l + return + } + + switch x { + case ' ', '\t': + if escape { + escape = false + str[stri] = x + stri++ + break + } + if quote { + // Inside quotes this is legal + str[stri] = x + stri++ + break + } + if commt { + com[comi] = x + comi++ + break + } + if stri == 0 { + // Space directly in the beginning, handled in the grammar + } else if owner { + // If we have a string and its the first, make it an owner + l.value = _OWNER + l.token = string(str[:stri]) + l.tokenUpper = strings.ToUpper(l.token) + l.length = stri + // escape $... 
start with a \ not a $, so this will work + switch l.tokenUpper { + case "$TTL": + l.value = _DIRTTL + case "$ORIGIN": + l.value = _DIRORIGIN + case "$INCLUDE": + l.value = _DIRINCLUDE + case "$GENERATE": + l.value = _DIRGENERATE + } + debug.Printf("[7 %+v]", l.token) + c <- l + } else { + l.value = _STRING + l.token = string(str[:stri]) + l.tokenUpper = strings.ToUpper(l.token) + l.length = stri + if !rrtype { + if t, ok := StringToType[l.tokenUpper]; ok { + l.value = _RRTYPE + l.torc = t + rrtype = true + } else { + if strings.HasPrefix(l.tokenUpper, "TYPE") { + if t, ok := typeToInt(l.token); !ok { + l.token = "unknown RR type" + l.err = true + c <- l + return + } else { + l.value = _RRTYPE + l.torc = t + } + } + } + if t, ok := StringToClass[l.tokenUpper]; ok { + l.value = _CLASS + l.torc = t + } else { + if strings.HasPrefix(l.tokenUpper, "CLASS") { + if t, ok := classToInt(l.token); !ok { + l.token = "unknown class" + l.err = true + c <- l + return + } else { + l.value = _CLASS + l.torc = t + } + } + } + } + debug.Printf("[6 %+v]", l.token) + c <- l + } + stri = 0 + // I reverse space stuff here + if !space && !commt { + l.value = _BLANK + l.token = " " + l.length = 1 + debug.Printf("[5 %+v]", l.token) + c <- l + } + owner = false + space = true + case ';': + if escape { + escape = false + str[stri] = x + stri++ + break + } + if quote { + // Inside quotes this is legal + str[stri] = x + stri++ + break + } + if stri > 0 { + l.value = _STRING + l.token = string(str[:stri]) + l.length = stri + debug.Printf("[4 %+v]", l.token) + c <- l + stri = 0 + } + commt = true + com[comi] = ';' + comi++ + case '\r': + escape = false + if quote { + str[stri] = x + stri++ + break + } + // discard if outside of quotes + case '\n': + escape = false + // Escaped newline + if quote { + str[stri] = x + stri++ + break + } + // inside quotes this is legal + if commt { + // Reset a comment + commt = false + rrtype = false + stri = 0 + // If not in a brace this ends the comment AND the RR + if brace == 0 { + owner = true + owner = true + l.value = _NEWLINE + l.token = "\n" + l.length = 1 + l.comment = string(com[:comi]) + debug.Printf("[3 %+v %+v]", l.token, l.comment) + c <- l + l.comment = "" + comi = 0 + break + } + com[comi] = ' ' // convert newline to space + comi++ + break + } + + if brace == 0 { + // If there is previous text, we should output it here + if stri != 0 { + l.value = _STRING + l.token = string(str[:stri]) + l.tokenUpper = strings.ToUpper(l.token) + + l.length = stri + if !rrtype { + if t, ok := StringToType[l.tokenUpper]; ok { + l.value = _RRTYPE + l.torc = t + rrtype = true + } + } + debug.Printf("[2 %+v]", l.token) + c <- l + } + l.value = _NEWLINE + l.token = "\n" + l.length = 1 + debug.Printf("[1 %+v]", l.token) + c <- l + stri = 0 + commt = false + rrtype = false + owner = true + comi = 0 + } + case '\\': + // comments do not get escaped chars, everything is copied + if commt { + com[comi] = x + comi++ + break + } + // something already escaped must be in string + if escape { + str[stri] = x + stri++ + escape = false + break + } + // something escaped outside of string gets added to string + str[stri] = x + stri++ + escape = true + case '"': + if commt { + com[comi] = x + comi++ + break + } + if escape { + str[stri] = x + stri++ + escape = false + break + } + space = false + // send previous gathered text and the quote + if stri != 0 { + l.value = _STRING + l.token = string(str[:stri]) + l.length = stri + + debug.Printf("[%+v]", l.token) + c <- l + stri = 0 + } + + // send quote itself 
as separate token + l.value = _QUOTE + l.token = "\"" + l.length = 1 + c <- l + quote = !quote + case '(', ')': + if commt { + com[comi] = x + comi++ + break + } + if escape { + str[stri] = x + stri++ + escape = false + break + } + if quote { + str[stri] = x + stri++ + break + } + switch x { + case ')': + brace-- + if brace < 0 { + l.token = "extra closing brace" + l.err = true + debug.Printf("[%+v]", l.token) + c <- l + return + } + case '(': + brace++ + } + default: + escape = false + if commt { + com[comi] = x + comi++ + break + } + str[stri] = x + stri++ + space = false + } + x, err = s.tokenText() + } + if stri > 0 { + // Send remainder + l.token = string(str[:stri]) + l.length = stri + l.value = _STRING + debug.Printf("[%+v]", l.token) + c <- l + } +} + +// Extract the class number from CLASSxx +func classToInt(token string) (uint16, bool) { + class, ok := strconv.Atoi(token[5:]) + if ok != nil || class > maxUint16 { + return 0, false + } + return uint16(class), true +} + +// Extract the rr number from TYPExxx +func typeToInt(token string) (uint16, bool) { + typ, ok := strconv.Atoi(token[4:]) + if ok != nil || typ > maxUint16 { + return 0, false + } + return uint16(typ), true +} + +// Parse things like 2w, 2m, etc, Return the time in seconds. +func stringToTtl(token string) (uint32, bool) { + s := uint32(0) + i := uint32(0) + for _, c := range token { + switch c { + case 's', 'S': + s += i + i = 0 + case 'm', 'M': + s += i * 60 + i = 0 + case 'h', 'H': + s += i * 60 * 60 + i = 0 + case 'd', 'D': + s += i * 60 * 60 * 24 + i = 0 + case 'w', 'W': + s += i * 60 * 60 * 24 * 7 + i = 0 + case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': + i *= 10 + i += uint32(c) - '0' + default: + return 0, false + } + } + return s + i, true +} + +// Parse LOC records' [.][mM] into a +// mantissa exponent format. Token should contain the entire +// string (i.e. no spaces allowed) +func stringToCm(token string) (e, m uint8, ok bool) { + if token[len(token)-1] == 'M' || token[len(token)-1] == 'm' { + token = token[0 : len(token)-1] + } + s := strings.SplitN(token, ".", 2) + var meters, cmeters, val int + var err error + switch len(s) { + case 2: + if cmeters, err = strconv.Atoi(s[1]); err != nil { + return + } + fallthrough + case 1: + if meters, err = strconv.Atoi(s[0]); err != nil { + return + } + case 0: + // huh? + return 0, 0, false + } + ok = true + if meters > 0 { + e = 2 + val = meters + } else { + e = 0 + val = cmeters + } + for val > 10 { + e++ + val /= 10 + } + if e > 9 { + ok = false + } + m = uint8(val) + return +} + +func appendOrigin(name, origin string) string { + if origin == "." { + return name + origin + } + return name + "." + origin +} + +// LOC record helper function +func locCheckNorth(token string, latitude uint32) (uint32, bool) { + switch token { + case "n", "N": + return _LOC_EQUATOR + latitude, true + case "s", "S": + return _LOC_EQUATOR - latitude, true + } + return latitude, false +} + +// LOC record helper function +func locCheckEast(token string, longitude uint32) (uint32, bool) { + switch token { + case "e", "E": + return _LOC_EQUATOR + longitude, true + case "w", "W": + return _LOC_EQUATOR - longitude, true + } + return longitude, false +} + +// "Eat" the rest of the "line". 
Return potential comments +func slurpRemainder(c chan lex, f string) (*ParseError, string) { + l := <-c + com := "" + switch l.value { + case _BLANK: + l = <-c + com = l.comment + if l.value != _NEWLINE && l.value != _EOF { + return &ParseError{f, "garbage after rdata", l}, "" + } + case _NEWLINE: + com = l.comment + case _EOF: + default: + return &ParseError{f, "garbage after rdata", l}, "" + } + return nil, com +} + +// Parse a 64 bit-like ipv6 address: "0014:4fff:ff20:ee64" +// Used for NID and L64 record. +func stringToNodeID(l lex) (uint64, *ParseError) { + if len(l.token) < 19 { + return 0, &ParseError{l.token, "bad NID/L64 NodeID/Locator64", l} + } + // There must be three colons at fixes postitions, if not its a parse error + if l.token[4] != ':' && l.token[9] != ':' && l.token[14] != ':' { + return 0, &ParseError{l.token, "bad NID/L64 NodeID/Locator64", l} + } + s := l.token[0:4] + l.token[5:9] + l.token[10:14] + l.token[15:19] + u, e := strconv.ParseUint(s, 16, 64) + if e != nil { + return 0, &ParseError{l.token, "bad NID/L64 NodeID/Locator64", l} + } + return u, nil +} diff --git a/Godeps/_workspace/src/github.com/miekg/dns/zscan_rr.go b/Godeps/_workspace/src/github.com/miekg/dns/zscan_rr.go new file mode 100644 index 0000000000..957b9e15ab --- /dev/null +++ b/Godeps/_workspace/src/github.com/miekg/dns/zscan_rr.go @@ -0,0 +1,2178 @@ +package dns + +import ( + "encoding/base64" + "net" + "strconv" + "strings" +) + +type parserFunc struct { + // Func defines the function that parses the tokens and returns the RR + // or an error. The last string contains any comments in the line as + // they returned by the lexer as well. + Func func(h RR_Header, c chan lex, origin string, file string) (RR, *ParseError, string) + // Signals if the RR ending is of variable length, like TXT or records + // that have Hexadecimal or Base64 as their last element in the Rdata. Records + // that have a fixed ending or for instance A, AAAA, SOA and etc. + Variable bool +} + +// Parse the rdata of each rrtype. +// All data from the channel c is either _STRING or _BLANK. +// After the rdata there may come a _BLANK and then a _NEWLINE +// or immediately a _NEWLINE. If this is not the case we flag +// an *ParseError: garbage after rdata. 
+func setRR(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + parserfunc, ok := typeToparserFunc[h.Rrtype] + if ok { + r, e, cm := parserfunc.Func(h, c, o, f) + if parserfunc.Variable { + return r, e, cm + } + if e != nil { + return nil, e, "" + } + e, cm = slurpRemainder(c, f) + if e != nil { + return nil, e, "" + } + return r, nil, cm + } + // RFC3957 RR (Unknown RR handling) + return setRFC3597(h, c, o, f) +} + +// A remainder of the rdata with embedded spaces, return the parsed string (sans the spaces) +// or an error +func endingToString(c chan lex, errstr, f string) (string, *ParseError, string) { + s := "" + l := <-c // _STRING + for l.value != _NEWLINE && l.value != _EOF { + switch l.value { + case _STRING: + s += l.token + case _BLANK: // Ok + default: + return "", &ParseError{f, errstr, l}, "" + } + l = <-c + } + return s, nil, l.comment +} + +// A remainder of the rdata with embedded spaces, return the parsed string slice (sans the spaces) +// or an error +func endingToTxtSlice(c chan lex, errstr, f string) ([]string, *ParseError, string) { + // Get the remaining data until we see a NEWLINE + quote := false + l := <-c + var s []string + switch l.value == _QUOTE { + case true: // A number of quoted string + s = make([]string, 0) + empty := true + for l.value != _NEWLINE && l.value != _EOF { + switch l.value { + case _STRING: + empty = false + s = append(s, l.token) + case _BLANK: + if quote { + // _BLANK can only be seen in between txt parts. + return nil, &ParseError{f, errstr, l}, "" + } + case _QUOTE: + if empty && quote { + s = append(s, "") + } + quote = !quote + empty = true + default: + return nil, &ParseError{f, errstr, l}, "" + } + l = <-c + } + if quote { + return nil, &ParseError{f, errstr, l}, "" + } + case false: // Unquoted text record + s = make([]string, 1) + for l.value != _NEWLINE && l.value != _EOF { + s[0] += l.token + l = <-c + } + } + return s, nil, l.comment +} + +func setA(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(A) + rr.Hdr = h + + l := <-c + if l.length == 0 { // Dynamic updates. + return rr, nil, "" + } + rr.A = net.ParseIP(l.token) + if rr.A == nil { + return nil, &ParseError{f, "bad A A", l}, "" + } + return rr, nil, "" +} + +func setAAAA(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(AAAA) + rr.Hdr = h + + l := <-c + if l.length == 0 { + return rr, nil, "" + } + rr.AAAA = net.ParseIP(l.token) + if rr.AAAA == nil { + return nil, &ParseError{f, "bad AAAA AAAA", l}, "" + } + return rr, nil, "" +} + +func setNS(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(NS) + rr.Hdr = h + + l := <-c + rr.Ns = l.token + if l.length == 0 { + return rr, nil, "" + } + if l.token == "@" { + rr.Ns = o + return rr, nil, "" + } + _, ok := IsDomainName(l.token) + if !ok { + return nil, &ParseError{f, "bad NS Ns", l}, "" + } + if rr.Ns[l.length-1] != '.' { + rr.Ns = appendOrigin(rr.Ns, o) + } + return rr, nil, "" +} + +func setPTR(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(PTR) + rr.Hdr = h + + l := <-c + rr.Ptr = l.token + if l.length == 0 { // dynamic update rr. + return rr, nil, "" + } + if l.token == "@" { + rr.Ptr = o + return rr, nil, "" + } + _, ok := IsDomainName(l.token) + if !ok { + return nil, &ParseError{f, "bad PTR Ptr", l}, "" + } + if rr.Ptr[l.length-1] != '.' 
{ + rr.Ptr = appendOrigin(rr.Ptr, o) + } + return rr, nil, "" +} + +func setNSAPPTR(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(NSAPPTR) + rr.Hdr = h + + l := <-c + rr.Ptr = l.token + if l.length == 0 { + return rr, nil, "" + } + if l.token == "@" { + rr.Ptr = o + return rr, nil, "" + } + _, ok := IsDomainName(l.token) + if !ok { + return nil, &ParseError{f, "bad NSAP-PTR Ptr", l}, "" + } + if rr.Ptr[l.length-1] != '.' { + rr.Ptr = appendOrigin(rr.Ptr, o) + } + return rr, nil, "" +} + +func setRP(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(RP) + rr.Hdr = h + + l := <-c + rr.Mbox = l.token + if l.length == 0 { + return rr, nil, "" + } + if l.token == "@" { + rr.Mbox = o + } else { + _, ok := IsDomainName(l.token) + if !ok { + return nil, &ParseError{f, "bad RP Mbox", l}, "" + } + if rr.Mbox[l.length-1] != '.' { + rr.Mbox = appendOrigin(rr.Mbox, o) + } + } + <-c // _BLANK + l = <-c + rr.Txt = l.token + if l.token == "@" { + rr.Txt = o + return rr, nil, "" + } + _, ok := IsDomainName(l.token) + if !ok { + return nil, &ParseError{f, "bad RP Txt", l}, "" + } + if rr.Txt[l.length-1] != '.' { + rr.Txt = appendOrigin(rr.Txt, o) + } + return rr, nil, "" +} + +func setMR(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(MR) + rr.Hdr = h + + l := <-c + rr.Mr = l.token + if l.length == 0 { + return rr, nil, "" + } + if l.token == "@" { + rr.Mr = o + return rr, nil, "" + } + _, ok := IsDomainName(l.token) + if !ok { + return nil, &ParseError{f, "bad MR Mr", l}, "" + } + if rr.Mr[l.length-1] != '.' { + rr.Mr = appendOrigin(rr.Mr, o) + } + return rr, nil, "" +} + +func setMB(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(MB) + rr.Hdr = h + + l := <-c + rr.Mb = l.token + if l.length == 0 { + return rr, nil, "" + } + if l.token == "@" { + rr.Mb = o + return rr, nil, "" + } + _, ok := IsDomainName(l.token) + if !ok { + return nil, &ParseError{f, "bad MB Mb", l}, "" + } + if rr.Mb[l.length-1] != '.' { + rr.Mb = appendOrigin(rr.Mb, o) + } + return rr, nil, "" +} + +func setMG(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(MG) + rr.Hdr = h + + l := <-c + rr.Mg = l.token + if l.length == 0 { + return rr, nil, "" + } + if l.token == "@" { + rr.Mg = o + return rr, nil, "" + } + _, ok := IsDomainName(l.token) + if !ok { + return nil, &ParseError{f, "bad MG Mg", l}, "" + } + if rr.Mg[l.length-1] != '.' { + rr.Mg = appendOrigin(rr.Mg, o) + } + return rr, nil, "" +} + +func setHINFO(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(HINFO) + rr.Hdr = h + + l := <-c + rr.Cpu = l.token + <-c // _BLANK + l = <-c // _STRING + rr.Os = l.token + + return rr, nil, "" +} + +func setMINFO(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(MINFO) + rr.Hdr = h + + l := <-c + rr.Rmail = l.token + if l.length == 0 { + return rr, nil, "" + } + if l.token == "@" { + rr.Rmail = o + } else { + _, ok := IsDomainName(l.token) + if !ok { + return nil, &ParseError{f, "bad MINFO Rmail", l}, "" + } + if rr.Rmail[l.length-1] != '.' { + rr.Rmail = appendOrigin(rr.Rmail, o) + } + } + <-c // _BLANK + l = <-c + rr.Email = l.token + if l.token == "@" { + rr.Email = o + return rr, nil, "" + } + _, ok := IsDomainName(l.token) + if !ok { + return nil, &ParseError{f, "bad MINFO Email", l}, "" + } + if rr.Email[l.length-1] != '.' 
{ + rr.Email = appendOrigin(rr.Email, o) + } + return rr, nil, "" +} + +func setMF(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(MF) + rr.Hdr = h + + l := <-c + rr.Mf = l.token + if l.length == 0 { + return rr, nil, "" + } + if l.token == "@" { + rr.Mf = o + return rr, nil, "" + } + _, ok := IsDomainName(l.token) + if !ok { + return nil, &ParseError{f, "bad MF Mf", l}, "" + } + if rr.Mf[l.length-1] != '.' { + rr.Mf = appendOrigin(rr.Mf, o) + } + return rr, nil, "" +} + +func setMD(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(MD) + rr.Hdr = h + + l := <-c + rr.Md = l.token + if l.length == 0 { + return rr, nil, "" + } + if l.token == "@" { + rr.Md = o + return rr, nil, "" + } + _, ok := IsDomainName(l.token) + if !ok { + return nil, &ParseError{f, "bad MD Md", l}, "" + } + if rr.Md[l.length-1] != '.' { + rr.Md = appendOrigin(rr.Md, o) + } + return rr, nil, "" +} + +func setMX(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(MX) + rr.Hdr = h + + l := <-c + if l.length == 0 { + return rr, nil, "" + } + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad MX Pref", l}, "" + } else { + rr.Preference = uint16(i) + } + <-c // _BLANK + l = <-c // _STRING + rr.Mx = l.token + if l.token == "@" { + rr.Mx = o + return rr, nil, "" + } + _, ok := IsDomainName(l.token) + if !ok || l.length == 0 { + return nil, &ParseError{f, "bad MX Mx", l}, "" + } + if rr.Mx[l.length-1] != '.' { + rr.Mx = appendOrigin(rr.Mx, o) + } + return rr, nil, "" +} + +func setRT(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(RT) + rr.Hdr = h + l := <-c + if l.length == 0 { + return rr, nil, "" + } + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad RT Preference", l}, "" + } else { + rr.Preference = uint16(i) + } + <-c // _BLANK + l = <-c // _STRING + rr.Host = l.token + if l.token == "@" { + rr.Host = o + return rr, nil, "" + } + _, ok := IsDomainName(l.token) + if !ok || l.length == 0 { + return nil, &ParseError{f, "bad RT Host", l}, "" + } + if rr.Host[l.length-1] != '.' { + rr.Host = appendOrigin(rr.Host, o) + } + return rr, nil, "" +} + +func setAFSDB(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(AFSDB) + rr.Hdr = h + + l := <-c + if l.length == 0 { + return rr, nil, "" + } + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad AFSDB Subtype", l}, "" + } else { + rr.Subtype = uint16(i) + } + <-c // _BLANK + l = <-c // _STRING + rr.Hostname = l.token + if l.token == "@" { + rr.Hostname = o + return rr, nil, "" + } + _, ok := IsDomainName(l.token) + if !ok || l.length == 0 { + return nil, &ParseError{f, "bad AFSDB Hostname", l}, "" + } + if rr.Hostname[l.length-1] != '.' 
{ + rr.Hostname = appendOrigin(rr.Hostname, o) + } + return rr, nil, "" +} + +func setX25(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(X25) + rr.Hdr = h + + l := <-c + rr.PSDNAddress = l.token + return rr, nil, "" +} + +func setKX(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(KX) + rr.Hdr = h + + l := <-c + if l.length == 0 { + return rr, nil, "" + } + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad KX Pref", l}, "" + } else { + rr.Preference = uint16(i) + } + <-c // _BLANK + l = <-c // _STRING + rr.Exchanger = l.token + if l.token == "@" { + rr.Exchanger = o + return rr, nil, "" + } + _, ok := IsDomainName(l.token) + if !ok || l.length == 0 { + return nil, &ParseError{f, "bad KX Exchanger", l}, "" + } + if rr.Exchanger[l.length-1] != '.' { + rr.Exchanger = appendOrigin(rr.Exchanger, o) + } + return rr, nil, "" +} + +func setCNAME(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(CNAME) + rr.Hdr = h + + l := <-c + rr.Target = l.token + if l.length == 0 { + return rr, nil, "" + } + if l.token == "@" { + rr.Target = o + return rr, nil, "" + } + _, ok := IsDomainName(l.token) + if !ok || l.length == 0 { + return nil, &ParseError{f, "bad CNAME Target", l}, "" + } + if rr.Target[l.length-1] != '.' { + rr.Target = appendOrigin(rr.Target, o) + } + return rr, nil, "" +} + +func setDNAME(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(DNAME) + rr.Hdr = h + + l := <-c + rr.Target = l.token + if l.length == 0 { + return rr, nil, "" + } + if l.token == "@" { + rr.Target = o + return rr, nil, "" + } + _, ok := IsDomainName(l.token) + if !ok { + return nil, &ParseError{f, "bad CNAME Target", l}, "" + } + if rr.Target[l.length-1] != '.' { + rr.Target = appendOrigin(rr.Target, o) + } + return rr, nil, "" +} + +func setSOA(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(SOA) + rr.Hdr = h + + l := <-c + rr.Ns = l.token + if l.length == 0 { + return rr, nil, "" + } + <-c // _BLANK + if l.token == "@" { + rr.Ns = o + } else { + _, ok := IsDomainName(l.token) + if !ok { + return nil, &ParseError{f, "bad SOA Ns", l}, "" + } + if rr.Ns[l.length-1] != '.' { + rr.Ns = appendOrigin(rr.Ns, o) + } + } + + l = <-c + rr.Mbox = l.token + if l.token == "@" { + rr.Mbox = o + } else { + _, ok := IsDomainName(l.token) + if !ok || l.length == 0 { + return nil, &ParseError{f, "bad SOA Mbox", l}, "" + } + if rr.Mbox[l.length-1] != '.' 
{ + rr.Mbox = appendOrigin(rr.Mbox, o) + } + } + <-c // _BLANK + + var ( + v uint32 + ok bool + ) + for i := 0; i < 5; i++ { + l = <-c + if j, e := strconv.Atoi(l.token); e != nil { + if i == 0 { + // Serial should be a number + return nil, &ParseError{f, "bad SOA zone parameter", l}, "" + } + if v, ok = stringToTtl(l.token); !ok { + return nil, &ParseError{f, "bad SOA zone parameter", l}, "" + + } + } else { + v = uint32(j) + } + switch i { + case 0: + rr.Serial = v + <-c // _BLANK + case 1: + rr.Refresh = v + <-c // _BLANK + case 2: + rr.Retry = v + <-c // _BLANK + case 3: + rr.Expire = v + <-c // _BLANK + case 4: + rr.Minttl = v + } + } + return rr, nil, "" +} + +func setSRV(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(SRV) + rr.Hdr = h + + l := <-c + if l.length == 0 { + return rr, nil, "" + } + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad SRV Priority", l}, "" + } else { + rr.Priority = uint16(i) + } + <-c // _BLANK + l = <-c // _STRING + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad SRV Weight", l}, "" + } else { + rr.Weight = uint16(i) + } + <-c // _BLANK + l = <-c // _STRING + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad SRV Port", l}, "" + } else { + rr.Port = uint16(i) + } + <-c // _BLANK + l = <-c // _STRING + rr.Target = l.token + if l.token == "@" { + rr.Target = o + return rr, nil, "" + } + _, ok := IsDomainName(l.token) + if !ok || l.length == 0 { + return nil, &ParseError{f, "bad SRV Target", l}, "" + } + if rr.Target[l.length-1] != '.' { + rr.Target = appendOrigin(rr.Target, o) + } + return rr, nil, "" +} + +func setNAPTR(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(NAPTR) + rr.Hdr = h + + l := <-c + if l.length == 0 { + return rr, nil, "" + } + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad NAPTR Order", l}, "" + } else { + rr.Order = uint16(i) + } + <-c // _BLANK + l = <-c // _STRING + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad NAPTR Preference", l}, "" + } else { + rr.Preference = uint16(i) + } + // Flags + <-c // _BLANK + l = <-c // _QUOTE + if l.value != _QUOTE { + return nil, &ParseError{f, "bad NAPTR Flags", l}, "" + } + l = <-c // Either String or Quote + if l.value == _STRING { + rr.Flags = l.token + l = <-c // _QUOTE + if l.value != _QUOTE { + return nil, &ParseError{f, "bad NAPTR Flags", l}, "" + } + } else if l.value == _QUOTE { + rr.Flags = "" + } else { + return nil, &ParseError{f, "bad NAPTR Flags", l}, "" + } + + // Service + <-c // _BLANK + l = <-c // _QUOTE + if l.value != _QUOTE { + return nil, &ParseError{f, "bad NAPTR Service", l}, "" + } + l = <-c // Either String or Quote + if l.value == _STRING { + rr.Service = l.token + l = <-c // _QUOTE + if l.value != _QUOTE { + return nil, &ParseError{f, "bad NAPTR Service", l}, "" + } + } else if l.value == _QUOTE { + rr.Service = "" + } else { + return nil, &ParseError{f, "bad NAPTR Service", l}, "" + } + + // Regexp + <-c // _BLANK + l = <-c // _QUOTE + if l.value != _QUOTE { + return nil, &ParseError{f, "bad NAPTR Regexp", l}, "" + } + l = <-c // Either String or Quote + if l.value == _STRING { + rr.Regexp = l.token + l = <-c // _QUOTE + if l.value != _QUOTE { + return nil, &ParseError{f, "bad NAPTR Regexp", l}, "" + } + } else if l.value == _QUOTE { + rr.Regexp = "" + } else { + return nil, &ParseError{f, "bad NAPTR Regexp", l}, "" + } + // After quote no space?? 
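+	// For reference, a complete NAPTR rdata looks like this (an illustrative
+	// example; the three quoted fields are Flags, Service and Regexp):
+	//
+	//	100 10 "u" "E2U+sip" "!^.*$!sip:info@example.org!" .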
+ <-c // _BLANK + l = <-c // _STRING + rr.Replacement = l.token + if l.token == "@" { + rr.Replacement = o + return rr, nil, "" + } + _, ok := IsDomainName(l.token) + if !ok || l.length == 0 { + return nil, &ParseError{f, "bad NAPTR Replacement", l}, "" + } + if rr.Replacement[l.length-1] != '.' { + rr.Replacement = appendOrigin(rr.Replacement, o) + } + return rr, nil, "" +} + +func setTALINK(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(TALINK) + rr.Hdr = h + + l := <-c + rr.PreviousName = l.token + if l.length == 0 { + return rr, nil, "" + } + if l.token == "@" { + rr.PreviousName = o + } else { + _, ok := IsDomainName(l.token) + if !ok { + return nil, &ParseError{f, "bad TALINK PreviousName", l}, "" + } + if rr.PreviousName[l.length-1] != '.' { + rr.PreviousName = appendOrigin(rr.PreviousName, o) + } + } + <-c // _BLANK + l = <-c + rr.NextName = l.token + if l.token == "@" { + rr.NextName = o + return rr, nil, "" + } + _, ok := IsDomainName(l.token) + if !ok || l.length == 0 { + return nil, &ParseError{f, "bad TALINK NextName", l}, "" + } + if rr.NextName[l.length-1] != '.' { + rr.NextName = appendOrigin(rr.NextName, o) + } + return rr, nil, "" +} + +func setLOC(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(LOC) + rr.Hdr = h + // Non zero defaults for LOC record, see RFC 1876, Section 3. + rr.HorizPre = 165 // 10000 + rr.VertPre = 162 // 10 + rr.Size = 18 // 1 + ok := false + // North + l := <-c + if l.length == 0 { + return rr, nil, "" + } + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad LOC Latitude", l}, "" + } else { + rr.Latitude = 1000 * 60 * 60 * uint32(i) + } + <-c // _BLANK + // Either number, 'N' or 'S' + l = <-c + if rr.Latitude, ok = locCheckNorth(l.token, rr.Latitude); ok { + goto East + } + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad LOC Latitude minutes", l}, "" + } else { + rr.Latitude += 1000 * 60 * uint32(i) + } + <-c // _BLANK + l = <-c + if i, e := strconv.ParseFloat(l.token, 32); e != nil { + return nil, &ParseError{f, "bad LOC Latitude seconds", l}, "" + } else { + rr.Latitude += uint32(1000 * i) + } + <-c // _BLANK + // Either number, 'N' or 'S' + l = <-c + if rr.Latitude, ok = locCheckNorth(l.token, rr.Latitude); ok { + goto East + } + // If still alive, flag an error + return nil, &ParseError{f, "bad LOC Latitude North/South", l}, "" + +East: + // East + <-c // _BLANK + l = <-c + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad LOC Longitude", l}, "" + } else { + rr.Longitude = 1000 * 60 * 60 * uint32(i) + } + <-c // _BLANK + // Either number, 'E' or 'W' + l = <-c + if rr.Longitude, ok = locCheckEast(l.token, rr.Longitude); ok { + goto Altitude + } + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad LOC Longitude minutes", l}, "" + } else { + rr.Longitude += 1000 * 60 * uint32(i) + } + <-c // _BLANK + l = <-c + if i, e := strconv.ParseFloat(l.token, 32); e != nil { + return nil, &ParseError{f, "bad LOC Longitude seconds", l}, "" + } else { + rr.Longitude += uint32(1000 * i) + } + <-c // _BLANK + // Either number, 'E' or 'W' + l = <-c + if rr.Longitude, ok = locCheckEast(l.token, rr.Longitude); ok { + goto Altitude + } + // If still alive, flag an error + return nil, &ParseError{f, "bad LOC Longitude East/West", l}, "" + +Altitude: + <-c // _BLANK + l = <-c + if l.token[len(l.token)-1] == 'M' || l.token[len(l.token)-1] == 'm' { + l.token = l.token[0 : len(l.token)-1] + } + if 
i, e := strconv.ParseFloat(l.token, 32); e != nil { + return nil, &ParseError{f, "bad LOC Altitude", l}, "" + } else { + rr.Altitude = uint32(i*100.0 + 10000000.0 + 0.5) + } + + // And now optionally the other values + l = <-c + count := 0 + for l.value != _NEWLINE && l.value != _EOF { + switch l.value { + case _STRING: + switch count { + case 0: // Size + if e, m, ok := stringToCm(l.token); !ok { + return nil, &ParseError{f, "bad LOC Size", l}, "" + } else { + rr.Size = (e & 0x0f) | (m << 4 & 0xf0) + } + case 1: // HorizPre + if e, m, ok := stringToCm(l.token); !ok { + return nil, &ParseError{f, "bad LOC HorizPre", l}, "" + } else { + rr.HorizPre = (e & 0x0f) | (m << 4 & 0xf0) + } + case 2: // VertPre + if e, m, ok := stringToCm(l.token); !ok { + return nil, &ParseError{f, "bad LOC VertPre", l}, "" + } else { + rr.VertPre = (e & 0x0f) | (m << 4 & 0xf0) + } + } + count++ + case _BLANK: + // Ok + default: + return nil, &ParseError{f, "bad LOC Size, HorizPre or VertPre", l}, "" + } + l = <-c + } + return rr, nil, "" +} + +func setHIP(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(HIP) + rr.Hdr = h + + // HitLength is not represented + l := <-c + if l.length == 0 { + return rr, nil, l.comment + } + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad HIP PublicKeyAlgorithm", l}, "" + } else { + rr.PublicKeyAlgorithm = uint8(i) + } + <-c // _BLANK + l = <-c // _STRING + rr.Hit = l.token // This cannot contain spaces, see RFC 5205 Section 6. + rr.HitLength = uint8(len(rr.Hit)) / 2 + + <-c // _BLANK + l = <-c // _STRING + rr.PublicKey = l.token // This cannot contain spaces + rr.PublicKeyLength = uint16(base64.StdEncoding.DecodedLen(len(rr.PublicKey))) + + // RendezvousServers (if any) + l = <-c + xs := make([]string, 0) + for l.value != _NEWLINE && l.value != _EOF { + switch l.value { + case _STRING: + if l.token == "@" { + xs = append(xs, o) + l = <-c // advance to the next token; a bare continue would skip the read at the bottom of the loop and spin forever + continue + } + _, ok := IsDomainName(l.token) + if !ok || l.length == 0 { + return nil, &ParseError{f, "bad HIP RendezvousServers", l}, "" + } + if l.token[l.length-1] != '.'
{ + l.token = appendOrigin(l.token, o) + } + xs = append(xs, l.token) + case _BLANK: + // Ok + default: + return nil, &ParseError{f, "bad HIP RendezvousServers", l}, "" + } + l = <-c + } + rr.RendezvousServers = xs + return rr, nil, l.comment +} + +func setCERT(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(CERT) + rr.Hdr = h + + l := <-c + if l.length == 0 { + return rr, nil, l.comment + } + if v, ok := StringToCertType[l.token]; ok { + rr.Type = v + } else if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad CERT Type", l}, "" + } else { + rr.Type = uint16(i) + } + <-c // _BLANK + l = <-c // _STRING + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad CERT KeyTag", l}, "" + } else { + rr.KeyTag = uint16(i) + } + <-c // _BLANK + l = <-c // _STRING + if v, ok := StringToAlgorithm[l.token]; ok { + rr.Algorithm = v + } else if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad CERT Algorithm", l}, "" + } else { + rr.Algorithm = uint8(i) + } + s, e, c1 := endingToString(c, "bad CERT Certificate", f) + if e != nil { + return nil, e, c1 + } + rr.Certificate = s + return rr, nil, c1 +} + +func setOPENPGPKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(OPENPGPKEY) + rr.Hdr = h + + s, e, c1 := endingToString(c, "bad OPENPGPKEY PublicKey", f) + if e != nil { + return nil, e, c1 + } + rr.PublicKey = s + return rr, nil, c1 +} + +func setRRSIG(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(RRSIG) + rr.Hdr = h + l := <-c + if l.length == 0 { + return rr, nil, l.comment + } + if t, ok := StringToType[l.tokenUpper]; !ok { + if strings.HasPrefix(l.tokenUpper, "TYPE") { + if t, ok = typeToInt(l.tokenUpper); !ok { + return nil, &ParseError{f, "bad RRSIG Typecovered", l}, "" + } else { + rr.TypeCovered = t + } + } else { + return nil, &ParseError{f, "bad RRSIG Typecovered", l}, "" + } + } else { + rr.TypeCovered = t + } + <-c // _BLANK + l = <-c + if i, err := strconv.Atoi(l.token); err != nil { + return nil, &ParseError{f, "bad RRSIG Algorithm", l}, "" + } else { + rr.Algorithm = uint8(i) + } + <-c // _BLANK + l = <-c + if i, err := strconv.Atoi(l.token); err != nil { + return nil, &ParseError{f, "bad RRSIG Labels", l}, "" + } else { + rr.Labels = uint8(i) + } + <-c // _BLANK + l = <-c + if i, err := strconv.Atoi(l.token); err != nil { + return nil, &ParseError{f, "bad RRSIG OrigTtl", l}, "" + } else { + rr.OrigTtl = uint32(i) + } + <-c // _BLANK + l = <-c + if i, err := StringToTime(l.token); err != nil { + // Try to see if all numeric and use it as epoch + if i, err := strconv.ParseInt(l.token, 10, 64); err == nil { + // TODO(miek): error out on > MAX_UINT32, same below + rr.Expiration = uint32(i) + } else { + return nil, &ParseError{f, "bad RRSIG Expiration", l}, "" + } + } else { + rr.Expiration = i + } + <-c // _BLANK + l = <-c + if i, err := StringToTime(l.token); err != nil { + if i, err := strconv.ParseInt(l.token, 10, 64); err == nil { + rr.Inception = uint32(i) + } else { + return nil, &ParseError{f, "bad RRSIG Inception", l}, "" + } + } else { + rr.Inception = i + } + <-c // _BLANK + l = <-c + if i, err := strconv.Atoi(l.token); err != nil { + return nil, &ParseError{f, "bad RRSIG KeyTag", l}, "" + } else { + rr.KeyTag = uint16(i) + } + <-c // _BLANK + l = <-c + rr.SignerName = l.token + if l.token == "@" { + rr.SignerName = o + } else { + _, ok := IsDomainName(l.token) + if !ok || l.length == 0 { + return nil, &ParseError{f, 
"bad RRSIG SignerName", l}, "" + } + if rr.SignerName[l.length-1] != '.' { + rr.SignerName = appendOrigin(rr.SignerName, o) + } + } + s, e, c1 := endingToString(c, "bad RRSIG Signature", f) + if e != nil { + return nil, e, c1 + } + rr.Signature = s + return rr, nil, c1 +} + +func setNSEC(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(NSEC) + rr.Hdr = h + + l := <-c + rr.NextDomain = l.token + if l.length == 0 { + return rr, nil, l.comment + } + if l.token == "@" { + rr.NextDomain = o + } else { + _, ok := IsDomainName(l.token) + if !ok { + return nil, &ParseError{f, "bad NSEC NextDomain", l}, "" + } + if rr.NextDomain[l.length-1] != '.' { + rr.NextDomain = appendOrigin(rr.NextDomain, o) + } + } + + rr.TypeBitMap = make([]uint16, 0) + var ( + k uint16 + ok bool + ) + l = <-c + for l.value != _NEWLINE && l.value != _EOF { + switch l.value { + case _BLANK: + // Ok + case _STRING: + if k, ok = StringToType[l.tokenUpper]; !ok { + if k, ok = typeToInt(l.tokenUpper); !ok { + return nil, &ParseError{f, "bad NSEC TypeBitMap", l}, "" + } + } + rr.TypeBitMap = append(rr.TypeBitMap, k) + default: + return nil, &ParseError{f, "bad NSEC TypeBitMap", l}, "" + } + l = <-c + } + return rr, nil, l.comment +} + +func setNSEC3(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(NSEC3) + rr.Hdr = h + + l := <-c + if l.length == 0 { + return rr, nil, l.comment + } + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad NSEC3 Hash", l}, "" + } else { + rr.Hash = uint8(i) + } + <-c // _BLANK + l = <-c + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad NSEC3 Flags", l}, "" + } else { + rr.Flags = uint8(i) + } + <-c // _BLANK + l = <-c + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad NSEC3 Iterations", l}, "" + } else { + rr.Iterations = uint16(i) + } + <-c + l = <-c + if len(l.token) == 0 { + return nil, &ParseError{f, "bad NSEC3 Salt", l}, "" + } + rr.SaltLength = uint8(len(l.token)) / 2 + rr.Salt = l.token + + <-c + l = <-c + rr.HashLength = 20 // Fix for NSEC3 (sha1 160 bits) + rr.NextDomain = l.token + + rr.TypeBitMap = make([]uint16, 0) + var ( + k uint16 + ok bool + ) + l = <-c + for l.value != _NEWLINE && l.value != _EOF { + switch l.value { + case _BLANK: + // Ok + case _STRING: + if k, ok = StringToType[l.tokenUpper]; !ok { + if k, ok = typeToInt(l.tokenUpper); !ok { + return nil, &ParseError{f, "bad NSEC3 TypeBitMap", l}, "" + } + } + rr.TypeBitMap = append(rr.TypeBitMap, k) + default: + return nil, &ParseError{f, "bad NSEC3 TypeBitMap", l}, "" + } + l = <-c + } + return rr, nil, l.comment +} + +func setNSEC3PARAM(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(NSEC3PARAM) + rr.Hdr = h + + l := <-c + if l.length == 0 { + return rr, nil, "" + } + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad NSEC3PARAM Hash", l}, "" + } else { + rr.Hash = uint8(i) + } + <-c // _BLANK + l = <-c + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad NSEC3PARAM Flags", l}, "" + } else { + rr.Flags = uint8(i) + } + <-c // _BLANK + l = <-c + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad NSEC3PARAM Iterations", l}, "" + } else { + rr.Iterations = uint16(i) + } + <-c + l = <-c + rr.SaltLength = uint8(len(l.token)) + rr.Salt = l.token + return rr, nil, "" +} + +func setEUI48(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(EUI48) + rr.Hdr = 
h + + l := <-c + if l.length == 0 { + return rr, nil, "" + } + if l.length != 17 { + return nil, &ParseError{f, "bad EUI48 Address", l}, "" + } + addr := make([]byte, 12) + dash := 0 + for i := 0; i < 10; i += 2 { + addr[i] = l.token[i+dash] + addr[i+1] = l.token[i+1+dash] + dash++ + if l.token[i+1+dash] != '-' { + return nil, &ParseError{f, "bad EUI48 Address", l}, "" + } + } + addr[10] = l.token[15] + addr[11] = l.token[16] + + if i, e := strconv.ParseUint(string(addr), 16, 48); e != nil { + return nil, &ParseError{f, "bad EUI48 Address", l}, "" + } else { + rr.Address = i + } + return rr, nil, "" +} + +func setEUI64(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(EUI64) + rr.Hdr = h + + l := <-c + if l.length == 0 { + return rr, nil, "" + } + if l.length != 23 { + return nil, &ParseError{f, "bad EUI64 Address", l}, "" + } + addr := make([]byte, 16) + dash := 0 + for i := 0; i < 14; i += 2 { + addr[i] = l.token[i+dash] + addr[i+1] = l.token[i+1+dash] + dash++ + if l.token[i+1+dash] != '-' { + return nil, &ParseError{f, "bad EUI64 Address", l}, "" + } + } + addr[14] = l.token[21] + addr[15] = l.token[22] + + if i, e := strconv.ParseUint(string(addr), 16, 64); e != nil { + return nil, &ParseError{f, "bad EUI64 Address", l}, "" + } else { + rr.Address = uint64(i) + } + return rr, nil, "" +} + +func setWKS(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(WKS) + rr.Hdr = h + + l := <-c + if l.length == 0 { + return rr, nil, l.comment + } + rr.Address = net.ParseIP(l.token) + if rr.Address == nil { + return nil, &ParseError{f, "bad WKS Address", l}, "" + } + + <-c // _BLANK + l = <-c + proto := "tcp" + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad WKS Protocol", l}, "" + } else { + rr.Protocol = uint8(i) + switch rr.Protocol { + case 17: + proto = "udp" + case 6: + proto = "tcp" + default: + return nil, &ParseError{f, "bad WKS Protocol", l}, "" + } + } + + <-c + l = <-c + rr.BitMap = make([]uint16, 0) + var ( + k int + err error + ) + for l.value != _NEWLINE && l.value != _EOF { + switch l.value { + case _BLANK: + // Ok + case _STRING: + if k, err = net.LookupPort(proto, l.token); err != nil { + if i, e := strconv.Atoi(l.token); e == nil { // not a known service name; if it is a number, use that + rr.BitMap = append(rr.BitMap, uint16(i)) + } else { + return nil, &ParseError{f, "bad WKS BitMap", l}, "" + } + } else { + rr.BitMap = append(rr.BitMap, uint16(k)) + } + default: + return nil, &ParseError{f, "bad WKS BitMap", l}, "" + } + l = <-c + } + return rr, nil, l.comment +} + +func setSSHFP(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(SSHFP) + rr.Hdr = h + + l := <-c + if l.length == 0 { + return rr, nil, "" + } + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad SSHFP Algorithm", l}, "" + } else { + rr.Algorithm = uint8(i) + } + <-c // _BLANK + l = <-c + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad SSHFP Type", l}, "" + } else { + rr.Type = uint8(i) + } + <-c // _BLANK + l = <-c + rr.FingerPrint = l.token + return rr, nil, "" +} + +func setDNSKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(DNSKEY) + rr.Hdr = h + + l := <-c + if l.length == 0 { + return rr, nil, l.comment + } + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad DNSKEY Flags", l}, "" + } else { + rr.Flags = uint16(i) + } + <-c // _BLANK + l = <-c // _STRING + if i, e := strconv.Atoi(l.token); e != nil { + return nil,
&ParseError{f, "bad DNSKEY Protocol", l}, "" + } else { + rr.Protocol = uint8(i) + } + <-c // _BLANK + l = <-c // _STRING + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad DNSKEY Algorithm", l}, "" + } else { + rr.Algorithm = uint8(i) + } + s, e, c1 := endingToString(c, "bad DNSKEY PublicKey", f) + if e != nil { + return nil, e, c1 + } + rr.PublicKey = s + return rr, nil, c1 +} + +func setRKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(RKEY) + rr.Hdr = h + + l := <-c + if l.length == 0 { + return rr, nil, l.comment + } + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad RKEY Flags", l}, "" + } else { + rr.Flags = uint16(i) + } + <-c // _BLANK + l = <-c // _STRING + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad RKEY Protocol", l}, "" + } else { + rr.Protocol = uint8(i) + } + <-c // _BLANK + l = <-c // _STRING + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad RKEY Algorithm", l}, "" + } else { + rr.Algorithm = uint8(i) + } + s, e, c1 := endingToString(c, "bad RKEY PublicKey", f) + if e != nil { + return nil, e, c1 + } + rr.PublicKey = s + return rr, nil, c1 +} + +func setDS(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(DS) + rr.Hdr = h + l := <-c + if l.length == 0 { + return rr, nil, l.comment + } + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad DS KeyTag", l}, "" + } else { + rr.KeyTag = uint16(i) + } + <-c // _BLANK + l = <-c + if i, e := strconv.Atoi(l.token); e != nil { + if i, ok := StringToAlgorithm[l.tokenUpper]; !ok { + return nil, &ParseError{f, "bad DS Algorithm", l}, "" + } else { + rr.Algorithm = i + } + } else { + rr.Algorithm = uint8(i) + } + <-c // _BLANK + l = <-c + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad DS DigestType", l}, "" + } else { + rr.DigestType = uint8(i) + } + s, e, c1 := endingToString(c, "bad DS Digest", f) + if e != nil { + return nil, e, c1 + } + rr.Digest = s + return rr, nil, c1 +} + +func setEID(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(EID) + rr.Hdr = h + s, e, c1 := endingToString(c, "bad EID Endpoint", f) + if e != nil { + return nil, e, c1 + } + rr.Endpoint = s + return rr, nil, c1 +} + +func setNIMLOC(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(NIMLOC) + rr.Hdr = h + s, e, c1 := endingToString(c, "bad NIMLOC Locator", f) + if e != nil { + return nil, e, c1 + } + rr.Locator = s + return rr, nil, c1 +} + +func setNSAP(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(NSAP) + rr.Hdr = h + l := <-c + if l.length == 0 { + return rr, nil, l.comment + } + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad NSAP Length", l}, "" + } else { + rr.Length = uint8(i) + } + <-c // _BLANK + s, e, c1 := endingToString(c, "bad NSAP Nsap", f) + if e != nil { + return nil, e, c1 + } + rr.Nsap = s + return rr, nil, c1 +} + +func setGPOS(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(GPOS) + rr.Hdr = h + l := <-c + if l.length == 0 { + return rr, nil, "" + } + if _, e := strconv.ParseFloat(l.token, 64); e != nil { + return nil, &ParseError{f, "bad GPOS Longitude", l}, "" + } else { + rr.Longitude = l.token + } + <-c // _BLANK + l = <-c + if _, e := strconv.ParseFloat(l.token, 64); e != nil { + return nil, &ParseError{f, "bad GPOS Latitude", l}, "" + } else { + 
rr.Latitude = l.token + } + <-c // _BLANK + l = <-c + if _, e := strconv.ParseFloat(l.token, 64); e != nil { + return nil, &ParseError{f, "bad GPOS Altitude", l}, "" + } else { + rr.Altitude = l.token + } + return rr, nil, "" +} + +func setCDS(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(CDS) + rr.Hdr = h + l := <-c + if l.length == 0 { + return rr, nil, l.comment + } + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad CDS KeyTag", l}, "" + } else { + rr.KeyTag = uint16(i) + } + <-c // _BLANK + l = <-c + if i, e := strconv.Atoi(l.token); e != nil { + if i, ok := StringToAlgorithm[l.tokenUpper]; !ok { + return nil, &ParseError{f, "bad CDS Algorithm", l}, "" + } else { + rr.Algorithm = i + } + } else { + rr.Algorithm = uint8(i) + } + <-c // _BLANK + l = <-c + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad CDS DigestType", l}, "" + } else { + rr.DigestType = uint8(i) + } + s, e, c1 := endingToString(c, "bad CDS Digest", f) + if e != nil { + return nil, e, c1 + } + rr.Digest = s + return rr, nil, c1 +} + +func setDLV(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(DLV) + rr.Hdr = h + l := <-c + if l.length == 0 { + return rr, nil, l.comment + } + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad DLV KeyTag", l}, "" + } else { + rr.KeyTag = uint16(i) + } + <-c // _BLANK + l = <-c + if i, e := strconv.Atoi(l.token); e != nil { + if i, ok := StringToAlgorithm[l.tokenUpper]; !ok { + return nil, &ParseError{f, "bad DLV Algorithm", l}, "" + } else { + rr.Algorithm = i + } + } else { + rr.Algorithm = uint8(i) + } + <-c // _BLANK + l = <-c + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad DLV DigestType", l}, "" + } else { + rr.DigestType = uint8(i) + } + s, e, c1 := endingToString(c, "bad DLV Digest", f) + if e != nil { + return nil, e, c1 + } + rr.Digest = s + return rr, nil, c1 +} + +func setTA(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(TA) + rr.Hdr = h + l := <-c + if l.length == 0 { + return rr, nil, l.comment + } + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad TA KeyTag", l}, "" + } else { + rr.KeyTag = uint16(i) + } + <-c // _BLANK + l = <-c + if i, e := strconv.Atoi(l.token); e != nil { + if i, ok := StringToAlgorithm[l.tokenUpper]; !ok { + return nil, &ParseError{f, "bad TA Algorithm", l}, "" + } else { + rr.Algorithm = i + } + } else { + rr.Algorithm = uint8(i) + } + <-c // _BLANK + l = <-c + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad TA DigestType", l}, "" + } else { + rr.DigestType = uint8(i) + } + s, e, c1 := endingToString(c, "bad TA Digest", f) + if e != nil { + return nil, e, c1 + } + rr.Digest = s + return rr, nil, c1 +} + +func setTLSA(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(TLSA) + rr.Hdr = h + l := <-c + if l.length == 0 { + return rr, nil, l.comment + } + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad TLSA Usage", l}, "" + } else { + rr.Usage = uint8(i) + } + <-c // _BLANK + l = <-c + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad TLSA Selector", l}, "" + } else { + rr.Selector = uint8(i) + } + <-c // _BLANK + l = <-c + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad TLSA MatchingType", l}, "" + } else { + rr.MatchingType = uint8(i) + } + s, e, c1 := endingToString(c, 
"bad TLSA Certificate", f) + if e != nil { + return nil, e, c1 + } + rr.Certificate = s + return rr, nil, c1 +} + +func setRFC3597(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(RFC3597) + rr.Hdr = h + l := <-c + if l.token != "\\#" { + return nil, &ParseError{f, "bad RFC3597 Rdata", l}, "" + } + <-c // _BLANK + l = <-c + rdlength, e := strconv.Atoi(l.token) + if e != nil { + return nil, &ParseError{f, "bad RFC3597 Rdata ", l}, "" + } + + s, e1, c1 := endingToString(c, "bad RFC3597 Rdata", f) + if e1 != nil { + return nil, e1, c1 + } + if rdlength*2 != len(s) { + return nil, &ParseError{f, "bad RFC3597 Rdata", l}, "" + } + rr.Rdata = s + return rr, nil, c1 +} + +func setSPF(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(SPF) + rr.Hdr = h + + s, e, c1 := endingToTxtSlice(c, "bad SPF Txt", f) + if e != nil { + return nil, e, "" + } + rr.Txt = s + return rr, nil, c1 +} + +func setTXT(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(TXT) + rr.Hdr = h + + // No _BLANK reading here, because this is all rdata is TXT + s, e, c1 := endingToTxtSlice(c, "bad TXT Txt", f) + if e != nil { + return nil, e, "" + } + rr.Txt = s + return rr, nil, c1 +} + +// identical to setTXT +func setNINFO(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(NINFO) + rr.Hdr = h + + s, e, c1 := endingToTxtSlice(c, "bad NINFO ZSData", f) + if e != nil { + return nil, e, "" + } + rr.ZSData = s + return rr, nil, c1 +} + +func setURI(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(URI) + rr.Hdr = h + + l := <-c + if l.length == 0 { + return rr, nil, l.comment + } + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad URI Priority", l}, "" + } else { + rr.Priority = uint16(i) + } + <-c // _BLANK + l = <-c + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad URI Weight", l}, "" + } else { + rr.Weight = uint16(i) + } + + <-c // _BLANK + s, e, c1 := endingToTxtSlice(c, "bad URI Target", f) + if e != nil { + return nil, e, "" + } + rr.Target = s + return rr, nil, c1 +} + +func setIPSECKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(IPSECKEY) + rr.Hdr = h + + l := <-c + if l.length == 0 { + return rr, nil, l.comment + } + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad IPSECKEY Precedence", l}, "" + } else { + rr.Precedence = uint8(i) + } + <-c // _BLANK + l = <-c + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad IPSECKEY GatewayType", l}, "" + } else { + rr.GatewayType = uint8(i) + } + <-c // _BLANK + l = <-c + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad IPSECKEY Algorithm", l}, "" + } else { + rr.Algorithm = uint8(i) + } + <-c + l = <-c + rr.Gateway = l.token + s, e, c1 := endingToString(c, "bad IPSECKEY PublicKey", f) + if e != nil { + return nil, e, c1 + } + rr.PublicKey = s + return rr, nil, c1 +} + +func setDHCID(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + // awesome record to parse! 
+ rr := new(DHCID) + rr.Hdr = h + + s, e, c1 := endingToString(c, "bad DHCID Digest", f) + if e != nil { + return nil, e, c1 + } + rr.Digest = s + return rr, nil, c1 +} + +func setNID(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(NID) + rr.Hdr = h + + l := <-c + if l.length == 0 { + return rr, nil, "" + } + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad NID Preference", l}, "" + } else { + rr.Preference = uint16(i) + } + <-c // _BLANK + l = <-c // _STRING + u, err := stringToNodeID(l) + if err != nil { + return nil, err, "" + } + rr.NodeID = u + return rr, nil, "" +} + +func setL32(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(L32) + rr.Hdr = h + + l := <-c + if l.length == 0 { + return rr, nil, "" + } + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad L32 Preference", l}, "" + } else { + rr.Preference = uint16(i) + } + <-c // _BLANK + l = <-c // _STRING + rr.Locator32 = net.ParseIP(l.token) + if rr.Locator32 == nil { + return nil, &ParseError{f, "bad L32 Locator", l}, "" + } + return rr, nil, "" +} + +func setLP(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(LP) + rr.Hdr = h + + l := <-c + if l.length == 0 { + return rr, nil, "" + } + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad LP Preference", l}, "" + } else { + rr.Preference = uint16(i) + } + <-c // _BLANK + l = <-c // _STRING + rr.Fqdn = l.token + if l.length == 0 { + return rr, nil, "" + } + if l.token == "@" { + rr.Fqdn = o + return rr, nil, "" + } + _, ok := IsDomainName(l.token) + if !ok || l.length == 0 { + return nil, &ParseError{f, "bad LP Fqdn", l}, "" + } + if rr.Fqdn[l.length-1] != '.' { + rr.Fqdn = appendOrigin(rr.Fqdn, o) + } + return rr, nil, "" +} + +func setL64(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(L64) + rr.Hdr = h + + l := <-c + if l.length == 0 { + return rr, nil, "" + } + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad L64 Preference", l}, "" + } else { + rr.Preference = uint16(i) + } + <-c // _BLANK + l = <-c // _STRING + u, err := stringToNodeID(l) + if err != nil { + return nil, err, "" + } + rr.Locator64 = u + return rr, nil, "" +} + +func setUID(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(UID) + rr.Hdr = h + l := <-c + if l.length == 0 { + return rr, nil, "" + } + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad UID Uid", l}, "" + } else { + rr.Uid = uint32(i) + } + return rr, nil, "" +} + +func setGID(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(GID) + rr.Hdr = h + l := <-c + if l.length == 0 { + return rr, nil, "" + } + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad GID Gid", l}, "" + } else { + rr.Gid = uint32(i) + } + return rr, nil, "" +} + +func setUINFO(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(UINFO) + rr.Hdr = h + s, e, c1 := endingToTxtSlice(c, "bad UINFO Uinfo", f) + if e != nil { + return nil, e, "" + } + rr.Uinfo = s[0] // silently discard anything above + return rr, nil, c1 +} + +func setPX(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { + rr := new(PX) + rr.Hdr = h + + l := <-c + if l.length == 0 { + return rr, nil, "" + } + if i, e := strconv.Atoi(l.token); e != nil { + return nil, &ParseError{f, "bad PX Preference", l}, "" + } else { + rr.Preference = 
uint16(i) + } + <-c // _BLANK + l = <-c // _STRING + rr.Map822 = l.token + if l.length == 0 { + return rr, nil, "" + } + if l.token == "@" { + rr.Map822 = o + return rr, nil, "" + } + _, ok := IsDomainName(l.token) + if !ok { + return nil, &ParseError{f, "bad PX Map822", l}, "" + } + if rr.Map822[l.length-1] != '.' { + rr.Map822 = appendOrigin(rr.Map822, o) + } + <-c // _BLANK + l = <-c // _STRING + rr.Mapx400 = l.token + if l.token == "@" { + rr.Mapx400 = o + return rr, nil, "" + } + _, ok = IsDomainName(l.token) + if !ok || l.length == 0 { + return nil, &ParseError{f, "bad PX Mapx400", l}, "" + } + if rr.Mapx400[l.length-1] != '.' { + rr.Mapx400 = appendOrigin(rr.Mapx400, o) + } + return rr, nil, "" +} + +var typeToparserFunc = map[uint16]parserFunc{ + TypeAAAA: parserFunc{setAAAA, false}, + TypeAFSDB: parserFunc{setAFSDB, false}, + TypeA: parserFunc{setA, false}, + TypeCDS: parserFunc{setCDS, true}, + TypeCERT: parserFunc{setCERT, true}, + TypeCNAME: parserFunc{setCNAME, false}, + TypeDHCID: parserFunc{setDHCID, true}, + TypeDLV: parserFunc{setDLV, true}, + TypeDNAME: parserFunc{setDNAME, false}, + TypeDNSKEY: parserFunc{setDNSKEY, true}, + TypeDS: parserFunc{setDS, true}, + TypeEID: parserFunc{setEID, true}, + TypeEUI48: parserFunc{setEUI48, false}, + TypeEUI64: parserFunc{setEUI64, false}, + TypeGID: parserFunc{setGID, false}, + TypeGPOS: parserFunc{setGPOS, false}, + TypeHINFO: parserFunc{setHINFO, false}, + TypeHIP: parserFunc{setHIP, true}, + TypeIPSECKEY: parserFunc{setIPSECKEY, true}, + TypeKX: parserFunc{setKX, false}, + TypeL32: parserFunc{setL32, false}, + TypeL64: parserFunc{setL64, false}, + TypeLOC: parserFunc{setLOC, true}, + TypeLP: parserFunc{setLP, false}, + TypeMB: parserFunc{setMB, false}, + TypeMD: parserFunc{setMD, false}, + TypeMF: parserFunc{setMF, false}, + TypeMG: parserFunc{setMG, false}, + TypeMINFO: parserFunc{setMINFO, false}, + TypeMR: parserFunc{setMR, false}, + TypeMX: parserFunc{setMX, false}, + TypeNAPTR: parserFunc{setNAPTR, false}, + TypeNID: parserFunc{setNID, false}, + TypeNIMLOC: parserFunc{setNIMLOC, true}, + TypeNINFO: parserFunc{setNINFO, true}, + TypeNSAP: parserFunc{setNSAP, true}, + TypeNSAPPTR: parserFunc{setNSAPPTR, false}, + TypeNSEC3PARAM: parserFunc{setNSEC3PARAM, false}, + TypeNSEC3: parserFunc{setNSEC3, true}, + TypeNSEC: parserFunc{setNSEC, true}, + TypeNS: parserFunc{setNS, false}, + TypeOPENPGPKEY: parserFunc{setOPENPGPKEY, true}, + TypePTR: parserFunc{setPTR, false}, + TypePX: parserFunc{setPX, false}, + TypeRKEY: parserFunc{setRKEY, true}, + TypeRP: parserFunc{setRP, false}, + TypeRRSIG: parserFunc{setRRSIG, true}, + TypeRT: parserFunc{setRT, false}, + TypeSOA: parserFunc{setSOA, false}, + TypeSPF: parserFunc{setSPF, true}, + TypeSRV: parserFunc{setSRV, false}, + TypeSSHFP: parserFunc{setSSHFP, false}, + TypeTALINK: parserFunc{setTALINK, false}, + TypeTA: parserFunc{setTA, true}, + TypeTLSA: parserFunc{setTLSA, true}, + TypeTXT: parserFunc{setTXT, true}, + TypeUID: parserFunc{setUID, false}, + TypeUINFO: parserFunc{setUINFO, true}, + TypeURI: parserFunc{setURI, true}, + TypeWKS: parserFunc{setWKS, true}, + TypeX25: parserFunc{setX25, false}, +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/batch.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/batch.go new file mode 100644 index 0000000000..ccf390c9cf --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/batch.go @@ -0,0 +1,252 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. 
+// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package leveldb + +import ( + "encoding/binary" + "fmt" + + "github.com/syndtr/goleveldb/leveldb/errors" + "github.com/syndtr/goleveldb/leveldb/memdb" +) + +type ErrBatchCorrupted struct { + Reason string +} + +func (e *ErrBatchCorrupted) Error() string { + return fmt.Sprintf("leveldb: batch corrupted: %s", e.Reason) +} + +func newErrBatchCorrupted(reason string) error { + return errors.NewErrCorrupted(nil, &ErrBatchCorrupted{reason}) +} + +const ( + batchHdrLen = 8 + 4 + batchGrowRec = 3000 +) + +type BatchReplay interface { + Put(key, value []byte) + Delete(key []byte) +} + +// Batch is a write batch. +type Batch struct { + data []byte + rLen, bLen int + seq uint64 + sync bool +} + +func (b *Batch) grow(n int) { + off := len(b.data) + if off == 0 { + off = batchHdrLen + if b.data != nil { + b.data = b.data[:off] + } + } + if cap(b.data)-off < n { + if b.data == nil { + b.data = make([]byte, off, off+n) + } else { + odata := b.data + div := 1 + if b.rLen > batchGrowRec { + div = b.rLen / batchGrowRec + } + b.data = make([]byte, off, off+n+(off-batchHdrLen)/div) + copy(b.data, odata) + } + } +} + +func (b *Batch) appendRec(kt kType, key, value []byte) { + n := 1 + binary.MaxVarintLen32 + len(key) + if kt == ktVal { + n += binary.MaxVarintLen32 + len(value) + } + b.grow(n) + off := len(b.data) + data := b.data[:off+n] + data[off] = byte(kt) + off += 1 + off += binary.PutUvarint(data[off:], uint64(len(key))) + copy(data[off:], key) + off += len(key) + if kt == ktVal { + off += binary.PutUvarint(data[off:], uint64(len(value))) + copy(data[off:], value) + off += len(value) + } + b.data = data[:off] + b.rLen++ + // Include 8-byte ikey header + b.bLen += len(key) + len(value) + 8 +} + +// Put appends 'put operation' of the given key/value pair to the batch. +// It is safe to modify the contents of the argument after Put returns. +func (b *Batch) Put(key, value []byte) { + b.appendRec(ktVal, key, value) +} + +// Delete appends 'delete operation' of the given key to the batch. +// It is safe to modify the contents of the argument after Delete returns. +func (b *Batch) Delete(key []byte) { + b.appendRec(ktDel, key, nil) +} + +// Dump dumps batch contents. The returned slice can be loaded into the +// batch using Load method. +// The returned slice is not its own copy, so the contents should not be +// modified. +func (b *Batch) Dump() []byte { + return b.encode() +} + +// Load loads given slice into the batch. Previous contents of the batch +// will be discarded. +// The given slice will not be copied and will be used as batch buffer, so +// it is not safe to modify the contents of the slice. +func (b *Batch) Load(data []byte) error { + return b.decode(0, data) +} + +// Replay replays batch contents. +func (b *Batch) Replay(r BatchReplay) error { + return b.decodeRec(func(i int, kt kType, key, value []byte) { + switch kt { + case ktVal: + r.Put(key, value) + case ktDel: + r.Delete(key) + } + }) +} + +// Len returns number of records in the batch. +func (b *Batch) Len() int { + return b.rLen +} + +// Reset resets the batch. +func (b *Batch) Reset() { + b.data = b.data[:0] + b.seq = 0 + b.rLen = 0 + b.bLen = 0 + b.sync = false +} + +func (b *Batch) init(sync bool) { + b.sync = sync +} + +func (b *Batch) append(p *Batch) { + if p.rLen > 0 { + b.grow(len(p.data) - batchHdrLen) + b.data = append(b.data, p.data[batchHdrLen:]...) 
+ b.rLen += p.rLen + } + if p.sync { + b.sync = true + } +} + +// size returns sums of key/value pair length plus 8-bytes ikey. +func (b *Batch) size() int { + return b.bLen +} + +func (b *Batch) encode() []byte { + b.grow(0) + binary.LittleEndian.PutUint64(b.data, b.seq) + binary.LittleEndian.PutUint32(b.data[8:], uint32(b.rLen)) + + return b.data +} + +func (b *Batch) decode(prevSeq uint64, data []byte) error { + if len(data) < batchHdrLen { + return newErrBatchCorrupted("too short") + } + + b.seq = binary.LittleEndian.Uint64(data) + if b.seq < prevSeq { + return newErrBatchCorrupted("invalid sequence number") + } + b.rLen = int(binary.LittleEndian.Uint32(data[8:])) + if b.rLen < 0 { + return newErrBatchCorrupted("invalid records length") + } + // No need to be precise at this point, it won't be used anyway + b.bLen = len(data) - batchHdrLen + b.data = data + + return nil +} + +func (b *Batch) decodeRec(f func(i int, kt kType, key, value []byte)) (err error) { + off := batchHdrLen + for i := 0; i < b.rLen; i++ { + if off >= len(b.data) { + return newErrBatchCorrupted("invalid records length") + } + + kt := kType(b.data[off]) + if kt > ktVal { + return newErrBatchCorrupted("bad record: invalid type") + } + off += 1 + + x, n := binary.Uvarint(b.data[off:]) + off += n + if n <= 0 || off+int(x) > len(b.data) { + return newErrBatchCorrupted("bad record: invalid key length") + } + key := b.data[off : off+int(x)] + off += int(x) + var value []byte + if kt == ktVal { + x, n := binary.Uvarint(b.data[off:]) + off += n + if n <= 0 || off+int(x) > len(b.data) { + return newErrBatchCorrupted("bad record: invalid value length") + } + value = b.data[off : off+int(x)] + off += int(x) + } + + f(i, kt, key, value) + } + + return nil +} + +func (b *Batch) memReplay(to *memdb.DB) error { + return b.decodeRec(func(i int, kt kType, key, value []byte) { + ikey := newIkey(key, b.seq+uint64(i), kt) + to.Put(ikey, value) + }) +} + +func (b *Batch) memDecodeAndReplay(prevSeq uint64, data []byte, to *memdb.DB) error { + if err := b.decode(prevSeq, data); err != nil { + return err + } + return b.memReplay(to) +} + +func (b *Batch) revertMemReplay(to *memdb.DB) error { + return b.decodeRec(func(i int, kt kType, key, value []byte) { + ikey := newIkey(key, b.seq+uint64(i), kt) + to.Delete(ikey) + }) +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/batch_test.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/batch_test.go new file mode 100644 index 0000000000..7fc842f4fe --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/batch_test.go @@ -0,0 +1,120 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+ +package leveldb + +import ( + "bytes" + "testing" + + "github.com/syndtr/goleveldb/leveldb/comparer" + "github.com/syndtr/goleveldb/leveldb/memdb" +) + +type tbRec struct { + kt kType + key, value []byte +} + +type testBatch struct { + rec []*tbRec +} + +func (p *testBatch) Put(key, value []byte) { + p.rec = append(p.rec, &tbRec{ktVal, key, value}) +} + +func (p *testBatch) Delete(key []byte) { + p.rec = append(p.rec, &tbRec{ktDel, key, nil}) +} + +func compareBatch(t *testing.T, b1, b2 *Batch) { + if b1.seq != b2.seq { + t.Errorf("invalid seq number want %d, got %d", b1.seq, b2.seq) + } + if b1.Len() != b2.Len() { + t.Fatalf("invalid record length want %d, got %d", b1.Len(), b2.Len()) + } + p1, p2 := new(testBatch), new(testBatch) + err := b1.Replay(p1) + if err != nil { + t.Fatal("error when replaying batch 1: ", err) + } + err = b2.Replay(p2) + if err != nil { + t.Fatal("error when replaying batch 2: ", err) + } + for i := range p1.rec { + r1, r2 := p1.rec[i], p2.rec[i] + if r1.kt != r2.kt { + t.Errorf("invalid type on record '%d' want %d, got %d", i, r1.kt, r2.kt) + } + if !bytes.Equal(r1.key, r2.key) { + t.Errorf("invalid key on record '%d' want %s, got %s", i, string(r1.key), string(r2.key)) + } + if r1.kt == ktVal { + if !bytes.Equal(r1.value, r2.value) { + t.Errorf("invalid value on record '%d' want %s, got %s", i, string(r1.value), string(r2.value)) + } + } + } +} + +func TestBatch_EncodeDecode(t *testing.T) { + b1 := new(Batch) + b1.seq = 10009 + b1.Put([]byte("key1"), []byte("value1")) + b1.Put([]byte("key2"), []byte("value2")) + b1.Delete([]byte("key1")) + b1.Put([]byte("k"), []byte("")) + b1.Put([]byte("zzzzzzzzzzz"), []byte("zzzzzzzzzzzzzzzzzzzzzzzz")) + b1.Delete([]byte("key10000")) + b1.Delete([]byte("k")) + buf := b1.encode() + b2 := new(Batch) + err := b2.decode(0, buf) + if err != nil { + t.Error("error when decoding batch: ", err) + } + compareBatch(t, b1, b2) +} + +func TestBatch_Append(t *testing.T) { + b1 := new(Batch) + b1.seq = 10009 + b1.Put([]byte("key1"), []byte("value1")) + b1.Put([]byte("key2"), []byte("value2")) + b1.Delete([]byte("key1")) + b1.Put([]byte("foo"), []byte("foovalue")) + b1.Put([]byte("bar"), []byte("barvalue")) + b2a := new(Batch) + b2a.seq = 10009 + b2a.Put([]byte("key1"), []byte("value1")) + b2a.Put([]byte("key2"), []byte("value2")) + b2a.Delete([]byte("key1")) + b2b := new(Batch) + b2b.Put([]byte("foo"), []byte("foovalue")) + b2b.Put([]byte("bar"), []byte("barvalue")) + b2a.append(b2b) + compareBatch(t, b1, b2a) +} + +func TestBatch_Size(t *testing.T) { + b := new(Batch) + for i := 0; i < 2; i++ { + b.Put([]byte("key1"), []byte("value1")) + b.Put([]byte("key2"), []byte("value2")) + b.Delete([]byte("key1")) + b.Put([]byte("foo"), []byte("foovalue")) + b.Put([]byte("bar"), []byte("barvalue")) + mem := memdb.New(&iComparer{comparer.DefaultComparer}, 0) + b.memReplay(mem) + if b.size() != mem.Size() { + t.Errorf("invalid batch size calculation, want=%d got=%d", mem.Size(), b.size()) + } + b.Reset() + } +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/bench_test.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/bench_test.go new file mode 100644 index 0000000000..91b426709d --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/bench_test.go @@ -0,0 +1,464 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+ +package leveldb + +import ( + "bytes" + "fmt" + "math/rand" + "os" + "path/filepath" + "runtime" + "testing" + + "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/opt" + "github.com/syndtr/goleveldb/leveldb/storage" +) + +func randomString(r *rand.Rand, n int) []byte { + b := new(bytes.Buffer) + for i := 0; i < n; i++ { + b.WriteByte(' ' + byte(r.Intn(95))) + } + return b.Bytes() +} + +func compressibleStr(r *rand.Rand, frac float32, n int) []byte { + nn := int(float32(n) * frac) + rb := randomString(r, nn) + b := make([]byte, 0, n+nn) + for len(b) < n { + b = append(b, rb...) + } + return b[:n] +} + +type valueGen struct { + src []byte + pos int +} + +func newValueGen(frac float32) *valueGen { + v := new(valueGen) + r := rand.New(rand.NewSource(301)) + v.src = make([]byte, 0, 1048576+100) + for len(v.src) < 1048576 { + v.src = append(v.src, compressibleStr(r, frac, 100)...) + } + return v +} + +func (v *valueGen) get(n int) []byte { + if v.pos+n > len(v.src) { + v.pos = 0 + } + v.pos += n + return v.src[v.pos-n : v.pos] +} + +var benchDB = filepath.Join(os.TempDir(), fmt.Sprintf("goleveldbbench-%d", os.Getuid())) + +type dbBench struct { + b *testing.B + stor storage.Storage + db *DB + + o *opt.Options + ro *opt.ReadOptions + wo *opt.WriteOptions + + keys, values [][]byte +} + +func openDBBench(b *testing.B, noCompress bool) *dbBench { + _, err := os.Stat(benchDB) + if err == nil { + err = os.RemoveAll(benchDB) + if err != nil { + b.Fatal("cannot remove old db: ", err) + } + } + + p := &dbBench{ + b: b, + o: &opt.Options{}, + ro: &opt.ReadOptions{}, + wo: &opt.WriteOptions{}, + } + p.stor, err = storage.OpenFile(benchDB) + if err != nil { + b.Fatal("cannot open stor: ", err) + } + if noCompress { + p.o.Compression = opt.NoCompression + } + + p.db, err = Open(p.stor, p.o) + if err != nil { + b.Fatal("cannot open db: ", err) + } + + runtime.GOMAXPROCS(runtime.NumCPU()) + return p +} + +func (p *dbBench) reopen() { + p.db.Close() + var err error + p.db, err = Open(p.stor, p.o) + if err != nil { + p.b.Fatal("Reopen: got error: ", err) + } +} + +func (p *dbBench) populate(n int) { + p.keys, p.values = make([][]byte, n), make([][]byte, n) + v := newValueGen(0.5) + for i := range p.keys { + p.keys[i], p.values[i] = []byte(fmt.Sprintf("%016d", i)), v.get(100) + } +} + +func (p *dbBench) randomize() { + m := len(p.keys) + times := m * 2 + r1, r2 := rand.New(rand.NewSource(0xdeadbeef)), rand.New(rand.NewSource(0xbeefface)) + for n := 0; n < times; n++ { + i, j := r1.Int()%m, r2.Int()%m + if i == j { + continue + } + p.keys[i], p.keys[j] = p.keys[j], p.keys[i] + p.values[i], p.values[j] = p.values[j], p.values[i] + } +} + +func (p *dbBench) writes(perBatch int) { + b := p.b + db := p.db + + n := len(p.keys) + m := n / perBatch + if n%perBatch > 0 { + m++ + } + batches := make([]Batch, m) + j := 0 + for i := range batches { + first := true + for ; j < n && ((j+1)%perBatch != 0 || first); j++ { + first = false + batches[i].Put(p.keys[j], p.values[j]) + } + } + runtime.GC() + + b.ResetTimer() + b.StartTimer() + for i := range batches { + err := db.Write(&(batches[i]), p.wo) + if err != nil { + b.Fatal("write failed: ", err) + } + } + b.StopTimer() + b.SetBytes(116) +} + +func (p *dbBench) gc() { + p.keys, p.values = nil, nil + runtime.GC() +} + +func (p *dbBench) puts() { + b := p.b + db := p.db + + b.ResetTimer() + b.StartTimer() + for i := range p.keys { + err := db.Put(p.keys[i], p.values[i], p.wo) + if err != nil { + b.Fatal("put failed: ", err) + } + } + 
b.StopTimer() + b.SetBytes(116) +} + +func (p *dbBench) fill() { + b := p.b + db := p.db + + perBatch := 10000 + batch := new(Batch) + for i, n := 0, len(p.keys); i < n; { + first := true + for ; i < n && ((i+1)%perBatch != 0 || first); i++ { + first = false + batch.Put(p.keys[i], p.values[i]) + } + err := db.Write(batch, p.wo) + if err != nil { + b.Fatal("write failed: ", err) + } + batch.Reset() + } +} + +func (p *dbBench) gets() { + b := p.b + db := p.db + + b.ResetTimer() + for i := range p.keys { + _, err := db.Get(p.keys[i], p.ro) + if err != nil { + b.Error("got error: ", err) + } + } + b.StopTimer() +} + +func (p *dbBench) seeks() { + b := p.b + + iter := p.newIter() + defer iter.Release() + b.ResetTimer() + for i := range p.keys { + if !iter.Seek(p.keys[i]) { + b.Error("value not found for: ", string(p.keys[i])) + } + } + b.StopTimer() +} + +func (p *dbBench) newIter() iterator.Iterator { + iter := p.db.NewIterator(nil, p.ro) + err := iter.Error() + if err != nil { + p.b.Fatal("cannot create iterator: ", err) + } + return iter +} + +func (p *dbBench) close() { + if bp, err := p.db.GetProperty("leveldb.blockpool"); err == nil { + p.b.Log("Block pool stats: ", bp) + } + p.db.Close() + p.stor.Close() + os.RemoveAll(benchDB) + p.db = nil + p.keys = nil + p.values = nil + runtime.GC() + runtime.GOMAXPROCS(1) +} + +func BenchmarkDBWrite(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.writes(1) + p.close() +} + +func BenchmarkDBWriteBatch(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.writes(1000) + p.close() +} + +func BenchmarkDBWriteUncompressed(b *testing.B) { + p := openDBBench(b, true) + p.populate(b.N) + p.writes(1) + p.close() +} + +func BenchmarkDBWriteBatchUncompressed(b *testing.B) { + p := openDBBench(b, true) + p.populate(b.N) + p.writes(1000) + p.close() +} + +func BenchmarkDBWriteRandom(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.randomize() + p.writes(1) + p.close() +} + +func BenchmarkDBWriteRandomSync(b *testing.B) { + p := openDBBench(b, false) + p.wo.Sync = true + p.populate(b.N) + p.writes(1) + p.close() +} + +func BenchmarkDBOverwrite(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.writes(1) + p.writes(1) + p.close() +} + +func BenchmarkDBOverwriteRandom(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.writes(1) + p.randomize() + p.writes(1) + p.close() +} + +func BenchmarkDBPut(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.puts() + p.close() +} + +func BenchmarkDBRead(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.fill() + p.gc() + + iter := p.newIter() + b.ResetTimer() + for iter.Next() { + } + iter.Release() + b.StopTimer() + b.SetBytes(116) + p.close() +} + +func BenchmarkDBReadGC(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.fill() + + iter := p.newIter() + b.ResetTimer() + for iter.Next() { + } + iter.Release() + b.StopTimer() + b.SetBytes(116) + p.close() +} + +func BenchmarkDBReadUncompressed(b *testing.B) { + p := openDBBench(b, true) + p.populate(b.N) + p.fill() + p.gc() + + iter := p.newIter() + b.ResetTimer() + for iter.Next() { + } + iter.Release() + b.StopTimer() + b.SetBytes(116) + p.close() +} + +func BenchmarkDBReadTable(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.fill() + p.reopen() + p.gc() + + iter := p.newIter() + b.ResetTimer() + for iter.Next() { + } + iter.Release() + b.StopTimer() + b.SetBytes(116) + p.close() +} + +func BenchmarkDBReadReverse(b *testing.B) { + 
p := openDBBench(b, false) + p.populate(b.N) + p.fill() + p.gc() + + iter := p.newIter() + b.ResetTimer() + iter.Last() + for iter.Prev() { + } + iter.Release() + b.StopTimer() + b.SetBytes(116) + p.close() +} + +func BenchmarkDBReadReverseTable(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.fill() + p.reopen() + p.gc() + + iter := p.newIter() + b.ResetTimer() + iter.Last() + for iter.Prev() { + } + iter.Release() + b.StopTimer() + b.SetBytes(116) + p.close() +} + +func BenchmarkDBSeek(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.fill() + p.seeks() + p.close() +} + +func BenchmarkDBSeekRandom(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.fill() + p.randomize() + p.seeks() + p.close() +} + +func BenchmarkDBGet(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.fill() + p.gets() + p.close() +} + +func BenchmarkDBGetRandom(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.fill() + p.randomize() + p.gets() + p.close() +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/cache/cache.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/cache/cache.go new file mode 100644 index 0000000000..c9670de5de --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/cache/cache.go @@ -0,0 +1,676 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Package cache provides interfaces and implementations of cache algorithms. +package cache + +import ( + "sync" + "sync/atomic" + "unsafe" + + "github.com/syndtr/goleveldb/leveldb/util" +) + +// Cacher provides an interface for implementing caching functionality. +// An implementation must be goroutine-safe. +type Cacher interface { + // Capacity returns cache capacity. + Capacity() int + + // SetCapacity sets cache capacity. + SetCapacity(capacity int) + + // Promote promotes the 'cache node'. + Promote(n *Node) + + // Ban evicts the 'cache node' and prevents subsequent 'promote'. + Ban(n *Node) + + // Evict evicts the 'cache node'. + Evict(n *Node) + + // EvictNS evicts 'cache nodes' with the given namespace. + EvictNS(ns uint64) + + // EvictAll evicts all 'cache nodes'. + EvictAll() + + // Close closes the 'cache tree'. + Close() error +} + +// Value is a 'cacheable object'. It may implement util.Releaser; if +// so, the Release method will be called once the object is released. +type Value interface{} + +type CacheGetter struct { + Cache *Cache + NS uint64 +} + +func (g *CacheGetter) Get(key uint64, setFunc func() (size int, value Value)) *Handle { + return g.Cache.Get(g.NS, key, setFunc) +} + +// The hash table implementation is based on: +// "Dynamic-Sized Nonblocking Hash Tables", by Yujie Liu, Kunlong Zhang, and Michael Spear. ACM Symposium on Principles of Distributed Computing, Jul 2014. + +const ( + mInitialSize = 1 << 4 + mOverflowThreshold = 1 << 5 + mOverflowGrowThreshold = 1 << 7 +) + +type mBucket struct { + mu sync.Mutex + node []*Node + frozen bool +} + +func (b *mBucket) freeze() []*Node { + b.mu.Lock() + defer b.mu.Unlock() + if !b.frozen { + b.frozen = true + } + return b.node +} + +func (b *mBucket) get(r *Cache, h *mNode, hash uint32, ns, key uint64, noset bool) (done, added bool, n *Node) { + b.mu.Lock() + + if b.frozen { + b.mu.Unlock() + return + } + + // Scan the node.
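+ // A hit increments the node's refcount while the bucket lock is still held, so the caller always receives a live reference.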
+ for _, n := range b.node { + if n.hash == hash && n.ns == ns && n.key == key { + atomic.AddInt32(&n.ref, 1) + b.mu.Unlock() + return true, false, n + } + } + + // Get only. + if noset { + b.mu.Unlock() + return true, false, nil + } + + // Create node. + n = &Node{ + r: r, + hash: hash, + ns: ns, + key: key, + ref: 1, + } + // Add node to bucket. + b.node = append(b.node, n) + bLen := len(b.node) + b.mu.Unlock() + + // Update counter. + grow := atomic.AddInt32(&r.nodes, 1) >= h.growThreshold + if bLen > mOverflowThreshold { + grow = grow || atomic.AddInt32(&h.overflow, 1) >= mOverflowGrowThreshold + } + + // Grow. + if grow && atomic.CompareAndSwapInt32(&h.resizeInProgess, 0, 1) { + nhLen := len(h.buckets) << 1 + nh := &mNode{ + buckets: make([]unsafe.Pointer, nhLen), + mask: uint32(nhLen) - 1, + pred: unsafe.Pointer(h), + growThreshold: int32(nhLen * mOverflowThreshold), + shrinkThreshold: int32(nhLen >> 1), + } + ok := atomic.CompareAndSwapPointer(&r.mHead, unsafe.Pointer(h), unsafe.Pointer(nh)) + if !ok { + panic("BUG: failed swapping head") + } + go nh.initBuckets() + } + + return true, true, n +} + +func (b *mBucket) delete(r *Cache, h *mNode, hash uint32, ns, key uint64) (done, deleted bool) { + b.mu.Lock() + + if b.frozen { + b.mu.Unlock() + return + } + + // Scan the node. + var ( + n *Node + bLen int + ) + for i := range b.node { + n = b.node[i] + if n.ns == ns && n.key == key { + if atomic.LoadInt32(&n.ref) == 0 { + deleted = true + + // Call releaser. + if n.value != nil { + if r, ok := n.value.(util.Releaser); ok { + r.Release() + } + n.value = nil + } + + // Remove node from bucket. + b.node = append(b.node[:i], b.node[i+1:]...) + bLen = len(b.node) + } + break + } + } + b.mu.Unlock() + + if deleted { + // Call OnDel. + for _, f := range n.onDel { + f() + } + + // Update counter. + atomic.AddInt32(&r.size, int32(n.size)*-1) + shrink := atomic.AddInt32(&r.nodes, -1) < h.shrinkThreshold + if bLen >= mOverflowThreshold { + atomic.AddInt32(&h.overflow, -1) + } + + // Shrink. + if shrink && len(h.buckets) > mInitialSize && atomic.CompareAndSwapInt32(&h.resizeInProgess, 0, 1) { + nhLen := len(h.buckets) >> 1 + nh := &mNode{ + buckets: make([]unsafe.Pointer, nhLen), + mask: uint32(nhLen) - 1, + pred: unsafe.Pointer(h), + growThreshold: int32(nhLen * mOverflowThreshold), + shrinkThreshold: int32(nhLen >> 1), + } + ok := atomic.CompareAndSwapPointer(&r.mHead, unsafe.Pointer(h), unsafe.Pointer(nh)) + if !ok { + panic("BUG: failed swapping head") + } + go nh.initBuckets() + } + } + + return true, deleted +} + +type mNode struct { + buckets []unsafe.Pointer // []*mBucket + mask uint32 + pred unsafe.Pointer // *mNode + resizeInProgess int32 + + overflow int32 + growThreshold int32 + shrinkThreshold int32 +} + +func (n *mNode) initBucket(i uint32) *mBucket { + if b := (*mBucket)(atomic.LoadPointer(&n.buckets[i])); b != nil { + return b + } + + p := (*mNode)(atomic.LoadPointer(&n.pred)) + if p != nil { + var node []*Node + if n.mask > p.mask { + // Grow. + pb := (*mBucket)(atomic.LoadPointer(&p.buckets[i&p.mask])) + if pb == nil { + pb = p.initBucket(i & p.mask) + } + m := pb.freeze() + // Split nodes. + for _, x := range m { + if x.hash&n.mask == i { + node = append(node, x) + } + } + } else { + // Shrink. 
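+ // The predecessor table is twice this size; its buckets i and i+len(n.buckets) both map to bucket i here and are merged.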
+ pb0 := (*mBucket)(atomic.LoadPointer(&p.buckets[i])) + if pb0 == nil { + pb0 = p.initBucket(i) + } + pb1 := (*mBucket)(atomic.LoadPointer(&p.buckets[i+uint32(len(n.buckets))])) + if pb1 == nil { + pb1 = p.initBucket(i + uint32(len(n.buckets))) + } + m0 := pb0.freeze() + m1 := pb1.freeze() + // Merge nodes. + node = make([]*Node, 0, len(m0)+len(m1)) + node = append(node, m0...) + node = append(node, m1...) + } + b := &mBucket{node: node} + if atomic.CompareAndSwapPointer(&n.buckets[i], nil, unsafe.Pointer(b)) { + if len(node) > mOverflowThreshold { + atomic.AddInt32(&n.overflow, int32(len(node)-mOverflowThreshold)) + } + return b + } + } + + return (*mBucket)(atomic.LoadPointer(&n.buckets[i])) +} + +func (n *mNode) initBuckets() { + for i := range n.buckets { + n.initBucket(uint32(i)) + } + atomic.StorePointer(&n.pred, nil) +} + +// Cache is a 'cache map'. +type Cache struct { + mu sync.RWMutex + mHead unsafe.Pointer // *mNode + nodes int32 + size int32 + cacher Cacher + closed bool +} + +// NewCache creates a new 'cache map'. The cacher is optional and +// may be nil. +func NewCache(cacher Cacher) *Cache { + h := &mNode{ + buckets: make([]unsafe.Pointer, mInitialSize), + mask: mInitialSize - 1, + growThreshold: int32(mInitialSize * mOverflowThreshold), + shrinkThreshold: 0, + } + for i := range h.buckets { + h.buckets[i] = unsafe.Pointer(&mBucket{}) + } + r := &Cache{ + mHead: unsafe.Pointer(h), + cacher: cacher, + } + return r +} + +func (r *Cache) getBucket(hash uint32) (*mNode, *mBucket) { + h := (*mNode)(atomic.LoadPointer(&r.mHead)) + i := hash & h.mask + b := (*mBucket)(atomic.LoadPointer(&h.buckets[i])) + if b == nil { + b = h.initBucket(i) + } + return h, b +} + +func (r *Cache) delete(n *Node) bool { + for { + h, b := r.getBucket(n.hash) + done, deleted := b.delete(r, h, n.hash, n.ns, n.key) + if done { + return deleted + } + } + return false +} + +// Nodes returns number of 'cache node' in the map. +func (r *Cache) Nodes() int { + return int(atomic.LoadInt32(&r.nodes)) +} + +// Size returns sums of 'cache node' size in the map. +func (r *Cache) Size() int { + return int(atomic.LoadInt32(&r.size)) +} + +// Capacity returns cache capacity. +func (r *Cache) Capacity() int { + if r.cacher == nil { + return 0 + } + return r.cacher.Capacity() +} + +// SetCapacity sets cache capacity. +func (r *Cache) SetCapacity(capacity int) { + if r.cacher != nil { + r.cacher.SetCapacity(capacity) + } +} + +// Get gets 'cache node' with the given namespace and key. +// If cache node is not found and setFunc is not nil, Get will atomically creates +// the 'cache node' by calling setFunc. Otherwise Get will returns nil. +// +// The returned 'cache handle' should be released after use by calling Release +// method. +func (r *Cache) Get(ns, key uint64, setFunc func() (size int, value Value)) *Handle { + r.mu.RLock() + defer r.mu.RUnlock() + if r.closed { + return nil + } + + hash := murmur32(ns, key, 0xf00) + for { + h, b := r.getBucket(hash) + done, _, n := b.get(r, h, hash, ns, key, setFunc == nil) + if done { + if n != nil { + n.mu.Lock() + if n.value == nil { + if setFunc == nil { + n.mu.Unlock() + n.unref() + return nil + } + + n.size, n.value = setFunc() + if n.value == nil { + n.size = 0 + n.mu.Unlock() + n.unref() + return nil + } + atomic.AddInt32(&r.size, int32(n.size)) + } + n.mu.Unlock() + if r.cacher != nil { + r.cacher.Promote(n) + } + return &Handle{unsafe.Pointer(n)} + } + + break + } + } + return nil +} + +// Delete removes and ban 'cache node' with the given namespace and key. 
+// A banned 'cache node' will never be inserted into the 'cache tree'. The
+// ban applies only to that particular 'cache node', so when a 'cache node'
+// is recreated it will not be banned.
+//
+// If onDel is not nil, then it will be executed if such 'cache node'
+// doesn't exist or once the 'cache node' is released.
+//
+// Delete returns true if such 'cache node' exists.
+func (r *Cache) Delete(ns, key uint64, onDel func()) bool {
+	r.mu.RLock()
+	defer r.mu.RUnlock()
+	if r.closed {
+		return false
+	}
+
+	hash := murmur32(ns, key, 0xf00)
+	for {
+		h, b := r.getBucket(hash)
+		done, _, n := b.get(r, h, hash, ns, key, true)
+		if done {
+			if n != nil {
+				if onDel != nil {
+					n.mu.Lock()
+					n.onDel = append(n.onDel, onDel)
+					n.mu.Unlock()
+				}
+				if r.cacher != nil {
+					r.cacher.Ban(n)
+				}
+				n.unref()
+				return true
+			}
+
+			break
+		}
+	}
+
+	if onDel != nil {
+		onDel()
+	}
+
+	return false
+}
+
+// Evict evicts 'cache node' with the given namespace and key. This will
+// simply call Cacher.Evict.
+//
+// Evict returns true if such 'cache node' exists.
+func (r *Cache) Evict(ns, key uint64) bool {
+	r.mu.RLock()
+	defer r.mu.RUnlock()
+	if r.closed {
+		return false
+	}
+
+	hash := murmur32(ns, key, 0xf00)
+	for {
+		h, b := r.getBucket(hash)
+		done, _, n := b.get(r, h, hash, ns, key, true)
+		if done {
+			if n != nil {
+				if r.cacher != nil {
+					r.cacher.Evict(n)
+				}
+				n.unref()
+				return true
+			}
+
+			break
+		}
+	}
+
+	return false
+}
+
+// EvictNS evicts 'cache node' with the given namespace. This will
+// simply call Cacher.EvictNS.
+func (r *Cache) EvictNS(ns uint64) {
+	r.mu.RLock()
+	defer r.mu.RUnlock()
+	if r.closed {
+		return
+	}
+
+	if r.cacher != nil {
+		r.cacher.EvictNS(ns)
+	}
+}
+
+// EvictAll evicts all 'cache node'. This will simply call Cacher.EvictAll.
+func (r *Cache) EvictAll() {
+	r.mu.RLock()
+	defer r.mu.RUnlock()
+	if r.closed {
+		return
+	}
+
+	if r.cacher != nil {
+		r.cacher.EvictAll()
+	}
+}
+
+// Close closes the 'cache map' and releases all 'cache node'.
+func (r *Cache) Close() error {
+	r.mu.Lock()
+	if !r.closed {
+		r.closed = true
+
+		if r.cacher != nil {
+			if err := r.cacher.Close(); err != nil {
+				return err
+			}
+		}
+
+		h := (*mNode)(r.mHead)
+		h.initBuckets()
+
+		for i := range h.buckets {
+			b := (*mBucket)(h.buckets[i])
+			for _, n := range b.node {
+				// Call releaser.
+				if n.value != nil {
+					if r, ok := n.value.(util.Releaser); ok {
+						r.Release()
+					}
+					n.value = nil
+				}
+
+				// Call OnDel.
+				for _, f := range n.onDel {
+					f()
+				}
+			}
+		}
+	}
+	r.mu.Unlock()
+	return nil
+}
+
+// Node is a 'cache node'.
+type Node struct {
+	r *Cache
+
+	hash    uint32
+	ns, key uint64
+
+	mu    sync.Mutex
+	size  int
+	value Value
+
+	ref   int32
+	onDel []func()
+
+	CacheData unsafe.Pointer
+}
+
+// NS returns this 'cache node' namespace.
+func (n *Node) NS() uint64 {
+	return n.ns
+}
+
+// Key returns this 'cache node' key.
+func (n *Node) Key() uint64 {
+	return n.key
+}
+
+// Size returns this 'cache node' size.
+func (n *Node) Size() int {
+	return n.size
+}
+
+// Value returns this 'cache node' value.
+func (n *Node) Value() Value {
+	return n.value
+}
+
+// Ref returns this 'cache node' ref counter.
+func (n *Node) Ref() int32 {
+	return atomic.LoadInt32(&n.ref)
+}
+
+// GetHandle returns a handle for this 'cache node'.
+func (n *Node) GetHandle() *Handle { + if atomic.AddInt32(&n.ref, 1) <= 1 { + panic("BUG: Node.GetHandle on zero ref") + } + return &Handle{unsafe.Pointer(n)} +} + +func (n *Node) unref() { + if atomic.AddInt32(&n.ref, -1) == 0 { + n.r.delete(n) + } +} + +func (n *Node) unrefLocked() { + if atomic.AddInt32(&n.ref, -1) == 0 { + n.r.mu.RLock() + if !n.r.closed { + n.r.delete(n) + } + n.r.mu.RUnlock() + } +} + +type Handle struct { + n unsafe.Pointer // *Node +} + +func (h *Handle) Value() Value { + n := (*Node)(atomic.LoadPointer(&h.n)) + if n != nil { + return n.value + } + return nil +} + +func (h *Handle) Release() { + nPtr := atomic.LoadPointer(&h.n) + if nPtr != nil && atomic.CompareAndSwapPointer(&h.n, nPtr, nil) { + n := (*Node)(nPtr) + n.unrefLocked() + } +} + +func murmur32(ns, key uint64, seed uint32) uint32 { + const ( + m = uint32(0x5bd1e995) + r = 24 + ) + + k1 := uint32(ns >> 32) + k2 := uint32(ns) + k3 := uint32(key >> 32) + k4 := uint32(key) + + k1 *= m + k1 ^= k1 >> r + k1 *= m + + k2 *= m + k2 ^= k2 >> r + k2 *= m + + k3 *= m + k3 ^= k3 >> r + k3 *= m + + k4 *= m + k4 ^= k4 >> r + k4 *= m + + h := seed + + h *= m + h ^= k1 + h *= m + h ^= k2 + h *= m + h ^= k3 + h *= m + h ^= k4 + + h ^= h >> 13 + h *= m + h ^= h >> 15 + + return h +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/cache/cache_test.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/cache/cache_test.go new file mode 100644 index 0000000000..5575583dce --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/cache/cache_test.go @@ -0,0 +1,570 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
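
Before the tests, a brief usage sketch of the API defined in cache.go above. This is an editorial illustration, not part of the vendored file; the namespace, key, charge, and capacity values are arbitrary.

package main

import "github.com/syndtr/goleveldb/leveldb/cache"

func main() {
	// A 'cache map' backed by an LRU cacher with capacity 100.
	c := cache.NewCache(cache.NewLRU(100))

	// Get either returns the existing node or atomically creates one by
	// calling setFunc; the returned handle pins the node until released.
	h := c.Get(1, 42, func() (size int, value cache.Value) {
		return 1, "answer" // charge one unit of capacity
	})
	if h != nil {
		_ = h.Value().(string) // "answer"
		h.Release()            // handles must be released after use
	}

	c.Delete(1, 42, nil) // remove and ban the node
	c.Close()
}
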
+ +package cache + +import ( + "math/rand" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" + "unsafe" +) + +type int32o int32 + +func (o *int32o) acquire() { + if atomic.AddInt32((*int32)(o), 1) != 1 { + panic("BUG: invalid ref") + } +} + +func (o *int32o) Release() { + if atomic.AddInt32((*int32)(o), -1) != 0 { + panic("BUG: invalid ref") + } +} + +type releaserFunc struct { + fn func() + value Value +} + +func (r releaserFunc) Release() { + if r.fn != nil { + r.fn() + } +} + +func set(c *Cache, ns, key uint64, value Value, charge int, relf func()) *Handle { + return c.Get(ns, key, func() (int, Value) { + if relf != nil { + return charge, releaserFunc{relf, value} + } else { + return charge, value + } + }) +} + +func TestCacheMap(t *testing.T) { + runtime.GOMAXPROCS(runtime.NumCPU()) + + nsx := []struct { + nobjects, nhandles, concurrent, repeat int + }{ + {10000, 400, 50, 3}, + {100000, 1000, 100, 10}, + } + + var ( + objects [][]int32o + handles [][]unsafe.Pointer + ) + + for _, x := range nsx { + objects = append(objects, make([]int32o, x.nobjects)) + handles = append(handles, make([]unsafe.Pointer, x.nhandles)) + } + + c := NewCache(nil) + + wg := new(sync.WaitGroup) + var done int32 + + for ns, x := range nsx { + for i := 0; i < x.concurrent; i++ { + wg.Add(1) + go func(ns, i, repeat int, objects []int32o, handles []unsafe.Pointer) { + defer wg.Done() + r := rand.New(rand.NewSource(time.Now().UnixNano())) + + for j := len(objects) * repeat; j >= 0; j-- { + key := uint64(r.Intn(len(objects))) + h := c.Get(uint64(ns), key, func() (int, Value) { + o := &objects[key] + o.acquire() + return 1, o + }) + if v := h.Value().(*int32o); v != &objects[key] { + t.Fatalf("#%d invalid value: want=%p got=%p", ns, &objects[key], v) + } + if objects[key] != 1 { + t.Fatalf("#%d invalid object %d: %d", ns, key, objects[key]) + } + if !atomic.CompareAndSwapPointer(&handles[r.Intn(len(handles))], nil, unsafe.Pointer(h)) { + h.Release() + } + } + }(ns, i, x.repeat, objects[ns], handles[ns]) + } + + go func(handles []unsafe.Pointer) { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + + for atomic.LoadInt32(&done) == 0 { + i := r.Intn(len(handles)) + h := (*Handle)(atomic.LoadPointer(&handles[i])) + if h != nil && atomic.CompareAndSwapPointer(&handles[i], unsafe.Pointer(h), nil) { + h.Release() + } + time.Sleep(time.Millisecond) + } + }(handles[ns]) + } + + go func() { + handles := make([]*Handle, 100000) + for atomic.LoadInt32(&done) == 0 { + for i := range handles { + handles[i] = c.Get(999999999, uint64(i), func() (int, Value) { + return 1, 1 + }) + } + for _, h := range handles { + h.Release() + } + } + }() + + wg.Wait() + + atomic.StoreInt32(&done, 1) + + for _, handles0 := range handles { + for i := range handles0 { + h := (*Handle)(atomic.LoadPointer(&handles0[i])) + if h != nil && atomic.CompareAndSwapPointer(&handles0[i], unsafe.Pointer(h), nil) { + h.Release() + } + } + } + + for ns, objects0 := range objects { + for i, o := range objects0 { + if o != 0 { + t.Fatalf("invalid object #%d.%d: ref=%d", ns, i, o) + } + } + } +} + +func TestCacheMap_NodesAndSize(t *testing.T) { + c := NewCache(nil) + if c.Nodes() != 0 { + t.Errorf("invalid nodes counter: want=%d got=%d", 0, c.Nodes()) + } + if c.Size() != 0 { + t.Errorf("invalid size counter: want=%d got=%d", 0, c.Size()) + } + set(c, 0, 1, 1, 1, nil) + set(c, 0, 2, 2, 2, nil) + set(c, 1, 1, 3, 3, nil) + set(c, 2, 1, 4, 1, nil) + if c.Nodes() != 4 { + t.Errorf("invalid nodes counter: want=%d got=%d", 4, c.Nodes()) + } + if c.Size() != 7 
{ + t.Errorf("invalid size counter: want=%d got=%d", 4, c.Size()) + } +} + +func TestLRUCache_Capacity(t *testing.T) { + c := NewCache(NewLRU(10)) + if c.Capacity() != 10 { + t.Errorf("invalid capacity: want=%d got=%d", 10, c.Capacity()) + } + set(c, 0, 1, 1, 1, nil).Release() + set(c, 0, 2, 2, 2, nil).Release() + set(c, 1, 1, 3, 3, nil).Release() + set(c, 2, 1, 4, 1, nil).Release() + set(c, 2, 2, 5, 1, nil).Release() + set(c, 2, 3, 6, 1, nil).Release() + set(c, 2, 4, 7, 1, nil).Release() + set(c, 2, 5, 8, 1, nil).Release() + if c.Nodes() != 7 { + t.Errorf("invalid nodes counter: want=%d got=%d", 7, c.Nodes()) + } + if c.Size() != 10 { + t.Errorf("invalid size counter: want=%d got=%d", 10, c.Size()) + } + c.SetCapacity(9) + if c.Capacity() != 9 { + t.Errorf("invalid capacity: want=%d got=%d", 9, c.Capacity()) + } + if c.Nodes() != 6 { + t.Errorf("invalid nodes counter: want=%d got=%d", 6, c.Nodes()) + } + if c.Size() != 8 { + t.Errorf("invalid size counter: want=%d got=%d", 8, c.Size()) + } +} + +func TestCacheMap_NilValue(t *testing.T) { + c := NewCache(NewLRU(10)) + h := c.Get(0, 0, func() (size int, value Value) { + return 1, nil + }) + if h != nil { + t.Error("cache handle is non-nil") + } + if c.Nodes() != 0 { + t.Errorf("invalid nodes counter: want=%d got=%d", 0, c.Nodes()) + } + if c.Size() != 0 { + t.Errorf("invalid size counter: want=%d got=%d", 0, c.Size()) + } +} + +func TestLRUCache_GetLatency(t *testing.T) { + runtime.GOMAXPROCS(runtime.NumCPU()) + + const ( + concurrentSet = 30 + concurrentGet = 3 + duration = 3 * time.Second + delay = 3 * time.Millisecond + maxkey = 100000 + ) + + var ( + set, getHit, getAll int32 + getMaxLatency, getDuration int64 + ) + + c := NewCache(NewLRU(5000)) + wg := &sync.WaitGroup{} + until := time.Now().Add(duration) + for i := 0; i < concurrentSet; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + r := rand.New(rand.NewSource(time.Now().UnixNano())) + for time.Now().Before(until) { + c.Get(0, uint64(r.Intn(maxkey)), func() (int, Value) { + time.Sleep(delay) + atomic.AddInt32(&set, 1) + return 1, 1 + }).Release() + } + }(i) + } + for i := 0; i < concurrentGet; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + r := rand.New(rand.NewSource(time.Now().UnixNano())) + for { + mark := time.Now() + if mark.Before(until) { + h := c.Get(0, uint64(r.Intn(maxkey)), nil) + latency := int64(time.Now().Sub(mark)) + m := atomic.LoadInt64(&getMaxLatency) + if latency > m { + atomic.CompareAndSwapInt64(&getMaxLatency, m, latency) + } + atomic.AddInt64(&getDuration, latency) + if h != nil { + atomic.AddInt32(&getHit, 1) + h.Release() + } + atomic.AddInt32(&getAll, 1) + } else { + break + } + } + }(i) + } + + wg.Wait() + getAvglatency := time.Duration(getDuration) / time.Duration(getAll) + t.Logf("set=%d getHit=%d getAll=%d getMaxLatency=%v getAvgLatency=%v", + set, getHit, getAll, time.Duration(getMaxLatency), getAvglatency) + + if getAvglatency > delay/3 { + t.Errorf("get avg latency > %v: got=%v", delay/3, getAvglatency) + } +} + +func TestLRUCache_HitMiss(t *testing.T) { + cases := []struct { + key uint64 + value string + }{ + {1, "vvvvvvvvv"}, + {100, "v1"}, + {0, "v2"}, + {12346, "v3"}, + {777, "v4"}, + {999, "v5"}, + {7654, "v6"}, + {2, "v7"}, + {3, "v8"}, + {9, "v9"}, + } + + setfin := 0 + c := NewCache(NewLRU(1000)) + for i, x := range cases { + set(c, 0, x.key, x.value, len(x.value), func() { + setfin++ + }).Release() + for j, y := range cases { + h := c.Get(0, y.key, nil) + if j <= i { + // should hit + if h == nil { + t.Errorf("case '%d' 
iteration '%d' is miss", i, j) + } else { + if x := h.Value().(releaserFunc).value.(string); x != y.value { + t.Errorf("case '%d' iteration '%d' has invalid value got '%s', want '%s'", i, j, x, y.value) + } + } + } else { + // should miss + if h != nil { + t.Errorf("case '%d' iteration '%d' is hit , value '%s'", i, j, h.Value().(releaserFunc).value.(string)) + } + } + if h != nil { + h.Release() + } + } + } + + for i, x := range cases { + finalizerOk := false + c.Delete(0, x.key, func() { + finalizerOk = true + }) + + if !finalizerOk { + t.Errorf("case %d delete finalizer not executed", i) + } + + for j, y := range cases { + h := c.Get(0, y.key, nil) + if j > i { + // should hit + if h == nil { + t.Errorf("case '%d' iteration '%d' is miss", i, j) + } else { + if x := h.Value().(releaserFunc).value.(string); x != y.value { + t.Errorf("case '%d' iteration '%d' has invalid value got '%s', want '%s'", i, j, x, y.value) + } + } + } else { + // should miss + if h != nil { + t.Errorf("case '%d' iteration '%d' is hit, value '%s'", i, j, h.Value().(releaserFunc).value.(string)) + } + } + if h != nil { + h.Release() + } + } + } + + if setfin != len(cases) { + t.Errorf("some set finalizer may not be executed, want=%d got=%d", len(cases), setfin) + } +} + +func TestLRUCache_Eviction(t *testing.T) { + c := NewCache(NewLRU(12)) + o1 := set(c, 0, 1, 1, 1, nil) + set(c, 0, 2, 2, 1, nil).Release() + set(c, 0, 3, 3, 1, nil).Release() + set(c, 0, 4, 4, 1, nil).Release() + set(c, 0, 5, 5, 1, nil).Release() + if h := c.Get(0, 2, nil); h != nil { // 1,3,4,5,2 + h.Release() + } + set(c, 0, 9, 9, 10, nil).Release() // 5,2,9 + + for _, key := range []uint64{9, 2, 5, 1} { + h := c.Get(0, key, nil) + if h == nil { + t.Errorf("miss for key '%d'", key) + } else { + if x := h.Value().(int); x != int(key) { + t.Errorf("invalid value for key '%d' want '%d', got '%d'", key, key, x) + } + h.Release() + } + } + o1.Release() + for _, key := range []uint64{1, 2, 5} { + h := c.Get(0, key, nil) + if h == nil { + t.Errorf("miss for key '%d'", key) + } else { + if x := h.Value().(int); x != int(key) { + t.Errorf("invalid value for key '%d' want '%d', got '%d'", key, key, x) + } + h.Release() + } + } + for _, key := range []uint64{3, 4, 9} { + h := c.Get(0, key, nil) + if h != nil { + t.Errorf("hit for key '%d'", key) + if x := h.Value().(int); x != int(key) { + t.Errorf("invalid value for key '%d' want '%d', got '%d'", key, key, x) + } + h.Release() + } + } +} + +func TestLRUCache_Evict(t *testing.T) { + c := NewCache(NewLRU(6)) + set(c, 0, 1, 1, 1, nil).Release() + set(c, 0, 2, 2, 1, nil).Release() + set(c, 1, 1, 4, 1, nil).Release() + set(c, 1, 2, 5, 1, nil).Release() + set(c, 2, 1, 6, 1, nil).Release() + set(c, 2, 2, 7, 1, nil).Release() + + for ns := 0; ns < 3; ns++ { + for key := 1; key < 3; key++ { + if h := c.Get(uint64(ns), uint64(key), nil); h != nil { + h.Release() + } else { + t.Errorf("Cache.Get on #%d.%d return nil", ns, key) + } + } + } + + if ok := c.Evict(0, 1); !ok { + t.Error("first Cache.Evict on #0.1 return false") + } + if ok := c.Evict(0, 1); ok { + t.Error("second Cache.Evict on #0.1 return true") + } + if h := c.Get(0, 1, nil); h != nil { + t.Errorf("Cache.Get on #0.1 return non-nil: %v", h.Value()) + } + + c.EvictNS(1) + if h := c.Get(1, 1, nil); h != nil { + t.Errorf("Cache.Get on #1.1 return non-nil: %v", h.Value()) + } + if h := c.Get(1, 2, nil); h != nil { + t.Errorf("Cache.Get on #1.2 return non-nil: %v", h.Value()) + } + + c.EvictAll() + for ns := 0; ns < 3; ns++ { + for key := 1; key < 3; key++ { + 
if h := c.Get(uint64(ns), uint64(key), nil); h != nil { + t.Errorf("Cache.Get on #%d.%d return non-nil: %v", ns, key, h.Value()) + } + } + } +} + +func TestLRUCache_Delete(t *testing.T) { + delFuncCalled := 0 + delFunc := func() { + delFuncCalled++ + } + + c := NewCache(NewLRU(2)) + set(c, 0, 1, 1, 1, nil).Release() + set(c, 0, 2, 2, 1, nil).Release() + + if ok := c.Delete(0, 1, delFunc); !ok { + t.Error("Cache.Delete on #1 return false") + } + if h := c.Get(0, 1, nil); h != nil { + t.Errorf("Cache.Get on #1 return non-nil: %v", h.Value()) + } + if ok := c.Delete(0, 1, delFunc); ok { + t.Error("Cache.Delete on #1 return true") + } + + h2 := c.Get(0, 2, nil) + if h2 == nil { + t.Error("Cache.Get on #2 return nil") + } + if ok := c.Delete(0, 2, delFunc); !ok { + t.Error("(1) Cache.Delete on #2 return false") + } + if ok := c.Delete(0, 2, delFunc); !ok { + t.Error("(2) Cache.Delete on #2 return false") + } + + set(c, 0, 3, 3, 1, nil).Release() + set(c, 0, 4, 4, 1, nil).Release() + c.Get(0, 2, nil).Release() + + for key := 2; key <= 4; key++ { + if h := c.Get(0, uint64(key), nil); h != nil { + h.Release() + } else { + t.Errorf("Cache.Get on #%d return nil", key) + } + } + + h2.Release() + if h := c.Get(0, 2, nil); h != nil { + t.Errorf("Cache.Get on #2 return non-nil: %v", h.Value()) + } + + if delFuncCalled != 4 { + t.Errorf("delFunc isn't called 4 times: got=%d", delFuncCalled) + } +} + +func TestLRUCache_Close(t *testing.T) { + relFuncCalled := 0 + relFunc := func() { + relFuncCalled++ + } + delFuncCalled := 0 + delFunc := func() { + delFuncCalled++ + } + + c := NewCache(NewLRU(2)) + set(c, 0, 1, 1, 1, relFunc).Release() + set(c, 0, 2, 2, 1, relFunc).Release() + + h3 := set(c, 0, 3, 3, 1, relFunc) + if h3 == nil { + t.Error("Cache.Get on #3 return nil") + } + if ok := c.Delete(0, 3, delFunc); !ok { + t.Error("Cache.Delete on #3 return false") + } + + c.Close() + + if relFuncCalled != 3 { + t.Errorf("relFunc isn't called 3 times: got=%d", relFuncCalled) + } + if delFuncCalled != 1 { + t.Errorf("delFunc isn't called 1 times: got=%d", delFuncCalled) + } +} + +func BenchmarkLRUCache(b *testing.B) { + c := NewCache(NewLRU(10000)) + + b.SetParallelism(10) + b.RunParallel(func(pb *testing.PB) { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + + for pb.Next() { + key := uint64(r.Intn(1000000)) + c.Get(0, key, func() (int, Value) { + return 1, key + }).Release() + } + }) +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/cache/lru.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/cache/lru.go new file mode 100644 index 0000000000..d9a84cde15 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/cache/lru.go @@ -0,0 +1,195 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
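
An orientation note for the file that follows (editorial, not upstream): lru keeps entries on an intrusive circular doubly-linked list whose sentinel is the recent field. Promote inserts at the MRU end (recent.next), capacity overflow evicts from the LRU end (recent.prev), and membership is recorded by stashing the *lruNode into Node.CacheData. A minimal, self-contained sketch of the same sentinel-list idea:

package main

import "fmt"

// ring is a stripped-down version of the sentinel list used by lru.
type ring struct {
	next, prev *ring
	v          int
}

func (s *ring) init() { s.next, s.prev = s, s }

// pushFront inserts n right after the sentinel, i.e. at the MRU position.
func (s *ring) pushFront(n *ring) {
	n.next, n.prev = s.next, s
	s.next.prev = n
	s.next = n
}

func main() {
	var s ring
	s.init()
	for i := 1; i <= 3; i++ {
		s.pushFront(&ring{v: i})
	}
	// Walk from the back (the LRU end), as eviction does: prints 1 2 3.
	for e := s.prev; e != &s; e = e.prev {
		fmt.Println(e.v)
	}
}
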
+ +package cache + +import ( + "sync" + "unsafe" +) + +type lruNode struct { + n *Node + h *Handle + ban bool + + next, prev *lruNode +} + +func (n *lruNode) insert(at *lruNode) { + x := at.next + at.next = n + n.prev = at + n.next = x + x.prev = n +} + +func (n *lruNode) remove() { + if n.prev != nil { + n.prev.next = n.next + n.next.prev = n.prev + n.prev = nil + n.next = nil + } else { + panic("BUG: removing removed node") + } +} + +type lru struct { + mu sync.Mutex + capacity int + used int + recent lruNode +} + +func (r *lru) reset() { + r.recent.next = &r.recent + r.recent.prev = &r.recent + r.used = 0 +} + +func (r *lru) Capacity() int { + r.mu.Lock() + defer r.mu.Unlock() + return r.capacity +} + +func (r *lru) SetCapacity(capacity int) { + var evicted []*lruNode + + r.mu.Lock() + r.capacity = capacity + for r.used > r.capacity { + rn := r.recent.prev + if rn == nil { + panic("BUG: invalid LRU used or capacity counter") + } + rn.remove() + rn.n.CacheData = nil + r.used -= rn.n.Size() + evicted = append(evicted, rn) + } + r.mu.Unlock() + + for _, rn := range evicted { + rn.h.Release() + } +} + +func (r *lru) Promote(n *Node) { + var evicted []*lruNode + + r.mu.Lock() + if n.CacheData == nil { + if n.Size() <= r.capacity { + rn := &lruNode{n: n, h: n.GetHandle()} + rn.insert(&r.recent) + n.CacheData = unsafe.Pointer(rn) + r.used += n.Size() + + for r.used > r.capacity { + rn := r.recent.prev + if rn == nil { + panic("BUG: invalid LRU used or capacity counter") + } + rn.remove() + rn.n.CacheData = nil + r.used -= rn.n.Size() + evicted = append(evicted, rn) + } + } + } else { + rn := (*lruNode)(n.CacheData) + if !rn.ban { + rn.remove() + rn.insert(&r.recent) + } + } + r.mu.Unlock() + + for _, rn := range evicted { + rn.h.Release() + } +} + +func (r *lru) Ban(n *Node) { + r.mu.Lock() + if n.CacheData == nil { + n.CacheData = unsafe.Pointer(&lruNode{n: n, ban: true}) + } else { + rn := (*lruNode)(n.CacheData) + if !rn.ban { + rn.remove() + rn.ban = true + r.used -= rn.n.Size() + r.mu.Unlock() + + rn.h.Release() + rn.h = nil + return + } + } + r.mu.Unlock() +} + +func (r *lru) Evict(n *Node) { + r.mu.Lock() + rn := (*lruNode)(n.CacheData) + if rn == nil || rn.ban { + r.mu.Unlock() + return + } + n.CacheData = nil + r.mu.Unlock() + + rn.h.Release() +} + +func (r *lru) EvictNS(ns uint64) { + var evicted []*lruNode + + r.mu.Lock() + for e := r.recent.prev; e != &r.recent; { + rn := e + e = e.prev + if rn.n.NS() == ns { + rn.remove() + rn.n.CacheData = nil + r.used -= rn.n.Size() + evicted = append(evicted, rn) + } + } + r.mu.Unlock() + + for _, rn := range evicted { + rn.h.Release() + } +} + +func (r *lru) EvictAll() { + r.mu.Lock() + back := r.recent.prev + for rn := back; rn != &r.recent; rn = rn.prev { + rn.n.CacheData = nil + } + r.reset() + r.mu.Unlock() + + for rn := back; rn != &r.recent; rn = rn.prev { + rn.h.Release() + } +} + +func (r *lru) Close() error { + return nil +} + +// NewLRU create a new LRU-cache. +func NewLRU(capacity int) Cacher { + r := &lru{capacity: capacity} + r.reset() + return r +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/comparer.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/comparer.go new file mode 100644 index 0000000000..d33d5e9c78 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/comparer.go @@ -0,0 +1,75 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. 
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+package leveldb
+
+import "github.com/syndtr/goleveldb/leveldb/comparer"
+
+type iComparer struct {
+	ucmp comparer.Comparer
+}
+
+func (icmp *iComparer) uName() string {
+	return icmp.ucmp.Name()
+}
+
+func (icmp *iComparer) uCompare(a, b []byte) int {
+	return icmp.ucmp.Compare(a, b)
+}
+
+func (icmp *iComparer) uSeparator(dst, a, b []byte) []byte {
+	return icmp.ucmp.Separator(dst, a, b)
+}
+
+func (icmp *iComparer) uSuccessor(dst, b []byte) []byte {
+	return icmp.ucmp.Successor(dst, b)
+}
+
+func (icmp *iComparer) Name() string {
+	return icmp.uName()
+}
+
+func (icmp *iComparer) Compare(a, b []byte) int {
+	x := icmp.ucmp.Compare(iKey(a).ukey(), iKey(b).ukey())
+	if x == 0 {
+		if m, n := iKey(a).num(), iKey(b).num(); m > n {
+			x = -1
+		} else if m < n {
+			x = 1
+		}
+	}
+	return x
+}
+
+func (icmp *iComparer) Separator(dst, a, b []byte) []byte {
+	ua, ub := iKey(a).ukey(), iKey(b).ukey()
+	dst = icmp.ucmp.Separator(dst, ua, ub)
+	if dst == nil {
+		return nil
+	}
+	if len(dst) < len(ua) && icmp.uCompare(ua, dst) < 0 {
+		dst = append(dst, kMaxNumBytes...)
+	} else {
+		// This does not rule out that the result may be longer than len(ua).
+		dst = append(dst, a[len(a)-8:]...)
+	}
+	return dst
+}
+
+func (icmp *iComparer) Successor(dst, b []byte) []byte {
+	ub := iKey(b).ukey()
+	dst = icmp.ucmp.Successor(dst, ub)
+	if dst == nil {
+		return nil
+	}
+	if len(dst) < len(ub) && icmp.uCompare(ub, dst) < 0 {
+		dst = append(dst, kMaxNumBytes...)
+	} else {
+		// This does not rule out that the result may be longer than len(ub).
+		dst = append(dst, b[len(b)-8:]...)
+	}
+	return dst
+}
diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/comparer/bytes_comparer.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/comparer/bytes_comparer.go
new file mode 100644
index 0000000000..14dddf88dd
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/comparer/bytes_comparer.go
@@ -0,0 +1,51 @@
+// Copyright (c) 2012, Suryandaru Triandana
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+package comparer
+
+import "bytes"
+
+type bytesComparer struct{}
+
+func (bytesComparer) Compare(a, b []byte) int {
+	return bytes.Compare(a, b)
+}
+
+func (bytesComparer) Name() string {
+	return "leveldb.BytewiseComparator"
+}
+
+func (bytesComparer) Separator(dst, a, b []byte) []byte {
+	i, n := 0, len(a)
+	if n > len(b) {
+		n = len(b)
+	}
+	for ; i < n && a[i] == b[i]; i++ {
+	}
+	if i >= n {
+		// Do not shorten if one string is a prefix of the other.
+	} else if c := a[i]; c < 0xff && c+1 < b[i] {
+		dst = append(dst, a[:i+1]...)
+		dst[i]++
+		return dst
+	}
+	return nil
+}
+
+func (bytesComparer) Successor(dst, b []byte) []byte {
+	for i, c := range b {
+		if c != 0xff {
+			dst = append(dst, b[:i+1]...)
+			dst[i]++
+			return dst
+		}
+	}
+	return nil
+}
+
+// DefaultComparer is the default implementation of the Comparer interface.
+// It uses the natural byte ordering, consistent with bytes.Compare.
+var DefaultComparer = bytesComparer{}
diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/comparer/comparer.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/comparer/comparer.go
new file mode 100644
index 0000000000..14a28f16fc
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/comparer/comparer.go
@@ -0,0 +1,57 @@
+// Copyright (c) 2012, Suryandaru Triandana
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+// Package comparer provides interface and implementations for ordering
+// sets of data.
+package comparer
+
+// BasicComparer is the interface that wraps the basic Compare method.
+type BasicComparer interface {
+	// Compare returns -1, 0, or +1 depending on whether a is 'less than',
+	// 'equal to' or 'greater than' b. The two arguments can only be 'equal'
+	// if their contents are exactly equal. Furthermore, the empty slice
+	// must be 'less than' any non-empty slice.
+	Compare(a, b []byte) int
+}
+
+// Comparer defines a total ordering over the space of []byte keys: a 'less
+// than' relationship.
+type Comparer interface {
+	BasicComparer
+
+	// Name returns name of the comparer.
+	//
+	// The Level-DB on-disk format stores the comparer name, and opening a
+	// database with a different comparer from the one it was created with
+	// will result in an error.
+	//
+	// An implementation should switch to a new name whenever the comparer
+	// implementation changes in a way that will cause the relative ordering
+	// of any two keys to change.
+	//
+	// Names starting with "leveldb." are reserved and should not be used
+	// by any users of this package.
+	Name() string
+
+	// Below are advanced functions used to reduce the space requirements
+	// for internal data structures such as index blocks.
+
+	// Separator appends a sequence of bytes x to dst such that a <= x && x < b,
+	// where 'less than' is consistent with Compare. An implementation should
+	// return nil if x is equal to a.
+	//
+	// The contents of a and b must not be modified in any way. Doing so
+	// may corrupt the internal state.
+	Separator(dst, a, b []byte) []byte
+
+	// Successor appends a sequence of bytes x to dst such that x >= b, where
+	// 'less than' is consistent with Compare. An implementation should return
+	// nil if x is equal to b.
+	//
+	// The contents of b must not be modified in any way. Doing so may
+	// corrupt the internal state.
+	Successor(dst, b []byte) []byte
+}
diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/corrupt_test.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/corrupt_test.go
new file mode 100644
index 0000000000..a351874ed4
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/corrupt_test.go
@@ -0,0 +1,500 @@
+// Copyright (c) 2013, Suryandaru Triandana
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
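
As an aside on the comparer contract defined above, here is what DefaultComparer yields for Separator and Successor (an editorial sketch, not upstream code; the byte strings are arbitrary examples):

package main

import (
	"fmt"

	"github.com/syndtr/goleveldb/leveldb/comparer"
)

func main() {
	c := comparer.DefaultComparer
	// Shortest x with "abcdef" <= x && x < "abzzz" found by Separator: "abd"
	// (first differing byte of a is bumped, the suffix is dropped).
	fmt.Printf("%q\n", c.Separator(nil, []byte("abcdef"), []byte("abzzz")))
	// Shortest x >= "abc" from Successor: "b" (first non-0xff byte bumped).
	fmt.Printf("%q\n", c.Successor(nil, []byte("abc")))
}
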
+ +package leveldb + +import ( + "bytes" + "fmt" + "github.com/syndtr/goleveldb/leveldb/filter" + "github.com/syndtr/goleveldb/leveldb/opt" + "github.com/syndtr/goleveldb/leveldb/storage" + "io" + "math/rand" + "testing" +) + +const ctValSize = 1000 + +type dbCorruptHarness struct { + dbHarness +} + +func newDbCorruptHarnessWopt(t *testing.T, o *opt.Options) *dbCorruptHarness { + h := new(dbCorruptHarness) + h.init(t, o) + return h +} + +func newDbCorruptHarness(t *testing.T) *dbCorruptHarness { + return newDbCorruptHarnessWopt(t, &opt.Options{ + BlockCacheCapacity: 100, + Strict: opt.StrictJournalChecksum, + }) +} + +func (h *dbCorruptHarness) recover() { + p := &h.dbHarness + t := p.t + + var err error + p.db, err = Recover(h.stor, h.o) + if err != nil { + t.Fatal("Repair: got error: ", err) + } +} + +func (h *dbCorruptHarness) build(n int) { + p := &h.dbHarness + t := p.t + db := p.db + + batch := new(Batch) + for i := 0; i < n; i++ { + batch.Reset() + batch.Put(tkey(i), tval(i, ctValSize)) + err := db.Write(batch, p.wo) + if err != nil { + t.Fatal("write error: ", err) + } + } +} + +func (h *dbCorruptHarness) buildShuffled(n int, rnd *rand.Rand) { + p := &h.dbHarness + t := p.t + db := p.db + + batch := new(Batch) + for i := range rnd.Perm(n) { + batch.Reset() + batch.Put(tkey(i), tval(i, ctValSize)) + err := db.Write(batch, p.wo) + if err != nil { + t.Fatal("write error: ", err) + } + } +} + +func (h *dbCorruptHarness) deleteRand(n, max int, rnd *rand.Rand) { + p := &h.dbHarness + t := p.t + db := p.db + + batch := new(Batch) + for i := 0; i < n; i++ { + batch.Reset() + batch.Delete(tkey(rnd.Intn(max))) + err := db.Write(batch, p.wo) + if err != nil { + t.Fatal("write error: ", err) + } + } +} + +func (h *dbCorruptHarness) corrupt(ft storage.FileType, fi, offset, n int) { + p := &h.dbHarness + t := p.t + + ff, _ := p.stor.GetFiles(ft) + sff := files(ff) + sff.sort() + if fi < 0 { + fi = len(sff) - 1 + } + if fi >= len(sff) { + t.Fatalf("no such file with type %q with index %d", ft, fi) + } + + file := sff[fi] + + r, err := file.Open() + if err != nil { + t.Fatal("cannot open file: ", err) + } + x, err := r.Seek(0, 2) + if err != nil { + t.Fatal("cannot query file size: ", err) + } + m := int(x) + if _, err := r.Seek(0, 0); err != nil { + t.Fatal(err) + } + + if offset < 0 { + if -offset > m { + offset = 0 + } else { + offset = m + offset + } + } + if offset > m { + offset = m + } + if offset+n > m { + n = m - offset + } + + buf := make([]byte, m) + _, err = io.ReadFull(r, buf) + if err != nil { + t.Fatal("cannot read file: ", err) + } + r.Close() + + for i := 0; i < n; i++ { + buf[offset+i] ^= 0x80 + } + + err = file.Remove() + if err != nil { + t.Fatal("cannot remove old file: ", err) + } + w, err := file.Create() + if err != nil { + t.Fatal("cannot create new file: ", err) + } + _, err = w.Write(buf) + if err != nil { + t.Fatal("cannot write new file: ", err) + } + w.Close() +} + +func (h *dbCorruptHarness) removeAll(ft storage.FileType) { + ff, err := h.stor.GetFiles(ft) + if err != nil { + h.t.Fatal("get files: ", err) + } + for _, f := range ff { + if err := f.Remove(); err != nil { + h.t.Error("remove file: ", err) + } + } +} + +func (h *dbCorruptHarness) removeOne(ft storage.FileType) { + ff, err := h.stor.GetFiles(ft) + if err != nil { + h.t.Fatal("get files: ", err) + } + f := ff[rand.Intn(len(ff))] + h.t.Logf("removing file @%d", f.Num()) + if err := f.Remove(); err != nil { + h.t.Error("remove file: ", err) + } +} + +func (h *dbCorruptHarness) check(min, max int) { + p := 
&h.dbHarness + t := p.t + db := p.db + + var n, badk, badv, missed, good int + iter := db.NewIterator(nil, p.ro) + for iter.Next() { + k := 0 + fmt.Sscanf(string(iter.Key()), "%d", &k) + if k < n { + badk++ + continue + } + missed += k - n + n = k + 1 + if !bytes.Equal(iter.Value(), tval(k, ctValSize)) { + badv++ + } else { + good++ + } + } + err := iter.Error() + iter.Release() + t.Logf("want=%d..%d got=%d badkeys=%d badvalues=%d missed=%d, err=%v", + min, max, good, badk, badv, missed, err) + if good < min || good > max { + t.Errorf("good entries number not in range") + } +} + +func TestCorruptDB_Journal(t *testing.T) { + h := newDbCorruptHarness(t) + + h.build(100) + h.check(100, 100) + h.closeDB() + h.corrupt(storage.TypeJournal, -1, 19, 1) + h.corrupt(storage.TypeJournal, -1, 32*1024+1000, 1) + + h.openDB() + h.check(36, 36) + + h.close() +} + +func TestCorruptDB_Table(t *testing.T) { + h := newDbCorruptHarness(t) + + h.build(100) + h.compactMem() + h.compactRangeAt(0, "", "") + h.compactRangeAt(1, "", "") + h.closeDB() + h.corrupt(storage.TypeTable, -1, 100, 1) + + h.openDB() + h.check(99, 99) + + h.close() +} + +func TestCorruptDB_TableIndex(t *testing.T) { + h := newDbCorruptHarness(t) + + h.build(10000) + h.compactMem() + h.closeDB() + h.corrupt(storage.TypeTable, -1, -2000, 500) + + h.openDB() + h.check(5000, 9999) + + h.close() +} + +func TestCorruptDB_MissingManifest(t *testing.T) { + rnd := rand.New(rand.NewSource(0x0badda7a)) + h := newDbCorruptHarnessWopt(t, &opt.Options{ + BlockCacheCapacity: 100, + Strict: opt.StrictJournalChecksum, + WriteBuffer: 1000 * 60, + }) + + h.build(1000) + h.compactMem() + h.buildShuffled(1000, rnd) + h.compactMem() + h.deleteRand(500, 1000, rnd) + h.compactMem() + h.buildShuffled(1000, rnd) + h.compactMem() + h.deleteRand(500, 1000, rnd) + h.compactMem() + h.buildShuffled(1000, rnd) + h.compactMem() + h.closeDB() + + h.stor.SetIgnoreOpenErr(storage.TypeManifest) + h.removeAll(storage.TypeManifest) + h.openAssert(false) + h.stor.SetIgnoreOpenErr(0) + + h.recover() + h.check(1000, 1000) + h.build(1000) + h.compactMem() + h.compactRange("", "") + h.closeDB() + + h.recover() + h.check(1000, 1000) + + h.close() +} + +func TestCorruptDB_SequenceNumberRecovery(t *testing.T) { + h := newDbCorruptHarness(t) + + h.put("foo", "v1") + h.put("foo", "v2") + h.put("foo", "v3") + h.put("foo", "v4") + h.put("foo", "v5") + h.closeDB() + + h.recover() + h.getVal("foo", "v5") + h.put("foo", "v6") + h.getVal("foo", "v6") + + h.reopenDB() + h.getVal("foo", "v6") + + h.close() +} + +func TestCorruptDB_SequenceNumberRecoveryTable(t *testing.T) { + h := newDbCorruptHarness(t) + + h.put("foo", "v1") + h.put("foo", "v2") + h.put("foo", "v3") + h.compactMem() + h.put("foo", "v4") + h.put("foo", "v5") + h.compactMem() + h.closeDB() + + h.recover() + h.getVal("foo", "v5") + h.put("foo", "v6") + h.getVal("foo", "v6") + + h.reopenDB() + h.getVal("foo", "v6") + + h.close() +} + +func TestCorruptDB_CorruptedManifest(t *testing.T) { + h := newDbCorruptHarness(t) + + h.put("foo", "hello") + h.compactMem() + h.compactRange("", "") + h.closeDB() + h.corrupt(storage.TypeManifest, -1, 0, 1000) + h.openAssert(false) + + h.recover() + h.getVal("foo", "hello") + + h.close() +} + +func TestCorruptDB_CompactionInputError(t *testing.T) { + h := newDbCorruptHarness(t) + + h.build(10) + h.compactMem() + h.closeDB() + h.corrupt(storage.TypeTable, -1, 100, 1) + + h.openDB() + h.check(9, 9) + + h.build(10000) + h.check(10000, 10000) + + h.close() +} + +func TestCorruptDB_UnrelatedKeys(t 
*testing.T) { + h := newDbCorruptHarness(t) + + h.build(10) + h.compactMem() + h.closeDB() + h.corrupt(storage.TypeTable, -1, 100, 1) + + h.openDB() + h.put(string(tkey(1000)), string(tval(1000, ctValSize))) + h.getVal(string(tkey(1000)), string(tval(1000, ctValSize))) + h.compactMem() + h.getVal(string(tkey(1000)), string(tval(1000, ctValSize))) + + h.close() +} + +func TestCorruptDB_Level0NewerFileHasOlderSeqnum(t *testing.T) { + h := newDbCorruptHarness(t) + + h.put("a", "v1") + h.put("b", "v1") + h.compactMem() + h.put("a", "v2") + h.put("b", "v2") + h.compactMem() + h.put("a", "v3") + h.put("b", "v3") + h.compactMem() + h.put("c", "v0") + h.put("d", "v0") + h.compactMem() + h.compactRangeAt(1, "", "") + h.closeDB() + + h.recover() + h.getVal("a", "v3") + h.getVal("b", "v3") + h.getVal("c", "v0") + h.getVal("d", "v0") + + h.close() +} + +func TestCorruptDB_RecoverInvalidSeq_Issue53(t *testing.T) { + h := newDbCorruptHarness(t) + + h.put("a", "v1") + h.put("b", "v1") + h.compactMem() + h.put("a", "v2") + h.put("b", "v2") + h.compactMem() + h.put("a", "v3") + h.put("b", "v3") + h.compactMem() + h.put("c", "v0") + h.put("d", "v0") + h.compactMem() + h.compactRangeAt(0, "", "") + h.closeDB() + + h.recover() + h.getVal("a", "v3") + h.getVal("b", "v3") + h.getVal("c", "v0") + h.getVal("d", "v0") + + h.close() +} + +func TestCorruptDB_MissingTableFiles(t *testing.T) { + h := newDbCorruptHarness(t) + + h.put("a", "v1") + h.put("b", "v1") + h.compactMem() + h.put("c", "v2") + h.put("d", "v2") + h.compactMem() + h.put("e", "v3") + h.put("f", "v3") + h.closeDB() + + h.removeOne(storage.TypeTable) + h.openAssert(false) + + h.close() +} + +func TestCorruptDB_RecoverTable(t *testing.T) { + h := newDbCorruptHarnessWopt(t, &opt.Options{ + WriteBuffer: 112 * opt.KiB, + CompactionTableSize: 90 * opt.KiB, + Filter: filter.NewBloomFilter(10), + }) + + h.build(1000) + h.compactMem() + h.compactRangeAt(0, "", "") + h.compactRangeAt(1, "", "") + seq := h.db.seq + h.closeDB() + h.corrupt(storage.TypeTable, 0, 1000, 1) + h.corrupt(storage.TypeTable, 3, 10000, 1) + // Corrupted filter shouldn't affect recovery. + h.corrupt(storage.TypeTable, 3, 113888, 10) + h.corrupt(storage.TypeTable, -1, 20000, 1) + + h.recover() + if h.db.seq != seq { + t.Errorf("invalid seq, want=%d got=%d", seq, h.db.seq) + } + h.check(985, 985) + + h.close() +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db.go new file mode 100644 index 0000000000..d50a00856d --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db.go @@ -0,0 +1,943 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package leveldb + +import ( + "container/list" + "fmt" + "io" + "os" + "runtime" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/syndtr/goleveldb/leveldb/errors" + "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/journal" + "github.com/syndtr/goleveldb/leveldb/memdb" + "github.com/syndtr/goleveldb/leveldb/opt" + "github.com/syndtr/goleveldb/leveldb/storage" + "github.com/syndtr/goleveldb/leveldb/table" + "github.com/syndtr/goleveldb/leveldb/util" +) + +// DB is a LevelDB database. +type DB struct { + // Need 64-bit alignment. + seq uint64 + + // Session. + s *session + + // MemDB. 
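+	// (Editor's note: two memdbs are kept; `mem` absorbs incoming writes
+	// while `frozenMem`, when non-nil, is being flushed to a level-0 table.)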
+ memMu sync.RWMutex + memPool chan *memdb.DB + mem, frozenMem *memDB + journal *journal.Writer + journalWriter storage.Writer + journalFile storage.File + frozenJournalFile storage.File + frozenSeq uint64 + + // Snapshot. + snapsMu sync.Mutex + snapsList *list.List + + // Stats. + aliveSnaps, aliveIters int32 + + // Write. + writeC chan *Batch + writeMergedC chan bool + writeLockC chan struct{} + writeAckC chan error + writeDelay time.Duration + writeDelayN int + journalC chan *Batch + journalAckC chan error + + // Compaction. + tcompCmdC chan cCmd + tcompPauseC chan chan<- struct{} + mcompCmdC chan cCmd + compErrC chan error + compPerErrC chan error + compErrSetC chan error + compStats []cStats + + // Close. + closeW sync.WaitGroup + closeC chan struct{} + closed uint32 + closer io.Closer +} + +func openDB(s *session) (*DB, error) { + s.log("db@open opening") + start := time.Now() + db := &DB{ + s: s, + // Initial sequence + seq: s.stSeqNum, + // MemDB + memPool: make(chan *memdb.DB, 1), + // Snapshot + snapsList: list.New(), + // Write + writeC: make(chan *Batch), + writeMergedC: make(chan bool), + writeLockC: make(chan struct{}, 1), + writeAckC: make(chan error), + journalC: make(chan *Batch), + journalAckC: make(chan error), + // Compaction + tcompCmdC: make(chan cCmd), + tcompPauseC: make(chan chan<- struct{}), + mcompCmdC: make(chan cCmd), + compErrC: make(chan error), + compPerErrC: make(chan error), + compErrSetC: make(chan error), + compStats: make([]cStats, s.o.GetNumLevel()), + // Close + closeC: make(chan struct{}), + } + + if err := db.recoverJournal(); err != nil { + return nil, err + } + + // Remove any obsolete files. + if err := db.checkAndCleanFiles(); err != nil { + // Close journal. + if db.journal != nil { + db.journal.Close() + db.journalWriter.Close() + } + return nil, err + } + + // Doesn't need to be included in the wait group. + go db.compactionError() + go db.mpoolDrain() + + db.closeW.Add(3) + go db.tCompaction() + go db.mCompaction() + go db.jWriter() + + s.logf("db@open done T·%v", time.Since(start)) + + runtime.SetFinalizer(db, (*DB).Close) + return db, nil +} + +// Open opens or creates a DB for the given storage. +// The DB will be created if not exist, unless ErrorIfMissing is true. +// Also, if ErrorIfExist is true and the DB exist Open will returns +// os.ErrExist error. +// +// Open will return an error with type of ErrCorrupted if corruption +// detected in the DB. Corrupted DB can be recovered with Recover +// function. +// +// The returned DB instance is goroutine-safe. +// The DB must be closed after use, by calling Close method. +func Open(stor storage.Storage, o *opt.Options) (db *DB, err error) { + s, err := newSession(stor, o) + if err != nil { + return + } + defer func() { + if err != nil { + s.close() + s.release() + } + }() + + err = s.recover() + if err != nil { + if !os.IsNotExist(err) || s.o.GetErrorIfMissing() { + return + } + err = s.create() + if err != nil { + return + } + } else if s.o.GetErrorIfExist() { + err = os.ErrExist + return + } + + return openDB(s) +} + +// OpenFile opens or creates a DB for the given path. +// The DB will be created if not exist, unless ErrorIfMissing is true. +// Also, if ErrorIfExist is true and the DB exist OpenFile will returns +// os.ErrExist error. +// +// OpenFile uses standard file-system backed storage implementation as +// desribed in the leveldb/storage package. +// +// OpenFile will return an error with type of ErrCorrupted if corruption +// detected in the DB. 
Corrupted DB can be recovered with Recover +// function. +// +// The returned DB instance is goroutine-safe. +// The DB must be closed after use, by calling Close method. +func OpenFile(path string, o *opt.Options) (db *DB, err error) { + stor, err := storage.OpenFile(path) + if err != nil { + return + } + db, err = Open(stor, o) + if err != nil { + stor.Close() + } else { + db.closer = stor + } + return +} + +// Recover recovers and opens a DB with missing or corrupted manifest files +// for the given storage. It will ignore any manifest files, valid or not. +// The DB must already exist or it will returns an error. +// Also, Recover will ignore ErrorIfMissing and ErrorIfExist options. +// +// The returned DB instance is goroutine-safe. +// The DB must be closed after use, by calling Close method. +func Recover(stor storage.Storage, o *opt.Options) (db *DB, err error) { + s, err := newSession(stor, o) + if err != nil { + return + } + defer func() { + if err != nil { + s.close() + s.release() + } + }() + + err = recoverTable(s, o) + if err != nil { + return + } + return openDB(s) +} + +// RecoverFile recovers and opens a DB with missing or corrupted manifest files +// for the given path. It will ignore any manifest files, valid or not. +// The DB must already exist or it will returns an error. +// Also, Recover will ignore ErrorIfMissing and ErrorIfExist options. +// +// RecoverFile uses standard file-system backed storage implementation as desribed +// in the leveldb/storage package. +// +// The returned DB instance is goroutine-safe. +// The DB must be closed after use, by calling Close method. +func RecoverFile(path string, o *opt.Options) (db *DB, err error) { + stor, err := storage.OpenFile(path) + if err != nil { + return + } + db, err = Recover(stor, o) + if err != nil { + stor.Close() + } else { + db.closer = stor + } + return +} + +func recoverTable(s *session, o *opt.Options) error { + o = dupOptions(o) + // Mask StrictReader, lets StrictRecovery doing its job. + o.Strict &= ^opt.StrictReader + + // Get all tables and sort it by file number. + tableFiles_, err := s.getFiles(storage.TypeTable) + if err != nil { + return err + } + tableFiles := files(tableFiles_) + tableFiles.sort() + + var ( + maxSeq uint64 + recoveredKey, goodKey, corruptedKey, corruptedBlock, droppedTable int + + // We will drop corrupted table. + strict = o.GetStrict(opt.StrictRecovery) + + rec = &sessionRecord{numLevel: o.GetNumLevel()} + bpool = util.NewBufferPool(o.GetBlockSize() + 5) + ) + buildTable := func(iter iterator.Iterator) (tmp storage.File, size int64, err error) { + tmp = s.newTemp() + writer, err := tmp.Create() + if err != nil { + return + } + defer func() { + writer.Close() + if err != nil { + tmp.Remove() + tmp = nil + } + }() + + // Copy entries. + tw := table.NewWriter(writer, o) + for iter.Next() { + key := iter.Key() + if validIkey(key) { + err = tw.Append(key, iter.Value()) + if err != nil { + return + } + } + } + err = iter.Error() + if err != nil { + return + } + err = tw.Close() + if err != nil { + return + } + err = writer.Sync() + if err != nil { + return + } + size = int64(tw.BytesLen()) + return + } + recoverTable := func(file storage.File) error { + s.logf("table@recovery recovering @%d", file.Num()) + reader, err := file.Open() + if err != nil { + return err + } + var closed bool + defer func() { + if !closed { + reader.Close() + } + }() + + // Get file size. 
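+		// (Editor's note: Seek(0, 2) seeks relative to end-of-file, so the
+		// returned offset is the file's size in bytes.)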
+ size, err := reader.Seek(0, 2) + if err != nil { + return err + } + + var ( + tSeq uint64 + tgoodKey, tcorruptedKey, tcorruptedBlock int + imin, imax []byte + ) + tr, err := table.NewReader(reader, size, storage.NewFileInfo(file), nil, bpool, o) + if err != nil { + return err + } + iter := tr.NewIterator(nil, nil) + iter.(iterator.ErrorCallbackSetter).SetErrorCallback(func(err error) { + if errors.IsCorrupted(err) { + s.logf("table@recovery block corruption @%d %q", file.Num(), err) + tcorruptedBlock++ + } + }) + + // Scan the table. + for iter.Next() { + key := iter.Key() + _, seq, _, kerr := parseIkey(key) + if kerr != nil { + tcorruptedKey++ + continue + } + tgoodKey++ + if seq > tSeq { + tSeq = seq + } + if imin == nil { + imin = append([]byte{}, key...) + } + imax = append(imax[:0], key...) + } + if err := iter.Error(); err != nil { + iter.Release() + return err + } + iter.Release() + + goodKey += tgoodKey + corruptedKey += tcorruptedKey + corruptedBlock += tcorruptedBlock + + if strict && (tcorruptedKey > 0 || tcorruptedBlock > 0) { + droppedTable++ + s.logf("table@recovery dropped @%d Gk·%d Ck·%d Cb·%d S·%d Q·%d", file.Num(), tgoodKey, tcorruptedKey, tcorruptedBlock, size, tSeq) + return nil + } + + if tgoodKey > 0 { + if tcorruptedKey > 0 || tcorruptedBlock > 0 { + // Rebuild the table. + s.logf("table@recovery rebuilding @%d", file.Num()) + iter := tr.NewIterator(nil, nil) + tmp, newSize, err := buildTable(iter) + iter.Release() + if err != nil { + return err + } + closed = true + reader.Close() + if err := file.Replace(tmp); err != nil { + return err + } + size = newSize + } + if tSeq > maxSeq { + maxSeq = tSeq + } + recoveredKey += tgoodKey + // Add table to level 0. + rec.addTable(0, file.Num(), uint64(size), imin, imax) + s.logf("table@recovery recovered @%d Gk·%d Ck·%d Cb·%d S·%d Q·%d", file.Num(), tgoodKey, tcorruptedKey, tcorruptedBlock, size, tSeq) + } else { + droppedTable++ + s.logf("table@recovery unrecoverable @%d Ck·%d Cb·%d S·%d", file.Num(), tcorruptedKey, tcorruptedBlock, size) + } + + return nil + } + + // Recover all tables. + if len(tableFiles) > 0 { + s.logf("table@recovery F·%d", len(tableFiles)) + + // Mark file number as used. + s.markFileNum(tableFiles[len(tableFiles)-1].Num()) + + for _, file := range tableFiles { + if err := recoverTable(file); err != nil { + return err + } + } + + s.logf("table@recovery recovered F·%d N·%d Gk·%d Ck·%d Q·%d", len(tableFiles), recoveredKey, goodKey, corruptedKey, maxSeq) + } + + // Set sequence number. + rec.setSeqNum(maxSeq) + + // Create new manifest. + if err := s.create(); err != nil { + return err + } + + // Commit. + return s.commit(rec) +} + +func (db *DB) recoverJournal() error { + // Get all tables and sort it by file number. + journalFiles_, err := db.s.getFiles(storage.TypeJournal) + if err != nil { + return err + } + journalFiles := files(journalFiles_) + journalFiles.sort() + + // Discard older journal. + prev := -1 + for i, file := range journalFiles { + if file.Num() >= db.s.stJournalNum { + if prev >= 0 { + i-- + journalFiles[i] = journalFiles[prev] + } + journalFiles = journalFiles[i:] + break + } else if file.Num() == db.s.stPrevJournalNum { + prev = i + } + } + + var jr *journal.Reader + var of storage.File + var mem *memdb.DB + batch := new(Batch) + cm := newCMem(db.s) + buf := new(util.Buffer) + // Options. 
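+	// (Editor's note: StrictJournal decides whether journal corruption
+	// aborts recovery or is logged and skipped; StrictJournalChecksum
+	// enables per-record checksum verification during the replay below.)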
+ strict := db.s.o.GetStrict(opt.StrictJournal) + checksum := db.s.o.GetStrict(opt.StrictJournalChecksum) + writeBuffer := db.s.o.GetWriteBuffer() + recoverJournal := func(file storage.File) error { + db.logf("journal@recovery recovering @%d", file.Num()) + reader, err := file.Open() + if err != nil { + return err + } + defer reader.Close() + + // Create/reset journal reader instance. + if jr == nil { + jr = journal.NewReader(reader, dropper{db.s, file}, strict, checksum) + } else { + jr.Reset(reader, dropper{db.s, file}, strict, checksum) + } + + // Flush memdb and remove obsolete journal file. + if of != nil { + if mem.Len() > 0 { + if err := cm.flush(mem, 0); err != nil { + return err + } + } + if err := cm.commit(file.Num(), db.seq); err != nil { + return err + } + cm.reset() + of.Remove() + of = nil + } + + // Replay journal to memdb. + mem.Reset() + for { + r, err := jr.Next() + if err != nil { + if err == io.EOF { + break + } + return errors.SetFile(err, file) + } + + buf.Reset() + if _, err := buf.ReadFrom(r); err != nil { + if err == io.ErrUnexpectedEOF { + // This is error returned due to corruption, with strict == false. + continue + } else { + return errors.SetFile(err, file) + } + } + if err := batch.memDecodeAndReplay(db.seq, buf.Bytes(), mem); err != nil { + if strict || !errors.IsCorrupted(err) { + return errors.SetFile(err, file) + } else { + db.s.logf("journal error: %v (skipped)", err) + // We won't apply sequence number as it might be corrupted. + continue + } + } + + // Save sequence number. + db.seq = batch.seq + uint64(batch.Len()) + + // Flush it if large enough. + if mem.Size() >= writeBuffer { + if err := cm.flush(mem, 0); err != nil { + return err + } + mem.Reset() + } + } + + of = file + return nil + } + + // Recover all journals. + if len(journalFiles) > 0 { + db.logf("journal@recovery F·%d", len(journalFiles)) + + // Mark file number as used. + db.s.markFileNum(journalFiles[len(journalFiles)-1].Num()) + + mem = memdb.New(db.s.icmp, writeBuffer) + for _, file := range journalFiles { + if err := recoverJournal(file); err != nil { + return err + } + } + + // Flush the last journal. + if mem.Len() > 0 { + if err := cm.flush(mem, 0); err != nil { + return err + } + } + } + + // Create a new journal. + if _, err := db.newMem(0); err != nil { + return err + } + + // Commit. + if err := cm.commit(db.journalFile.Num(), db.seq); err != nil { + // Close journal. + if db.journal != nil { + db.journal.Close() + db.journalWriter.Close() + } + return err + } + + // Remove the last obsolete journal file. + if of != nil { + of.Remove() + } + + return nil +} + +func (db *DB) get(key []byte, seq uint64, ro *opt.ReadOptions) (value []byte, err error) { + ikey := newIkey(key, seq, ktSeek) + + em, fm := db.getMems() + for _, m := range [...]*memDB{em, fm} { + if m == nil { + continue + } + defer m.decref() + + mk, mv, me := m.mdb.Find(ikey) + if me == nil { + ukey, _, kt, kerr := parseIkey(mk) + if kerr != nil { + // Shouldn't have had happen. + panic(kerr) + } + if db.s.icmp.uCompare(ukey, key) == 0 { + if kt == ktDel { + return nil, ErrNotFound + } + return append([]byte{}, mv...), nil + } + } else if me != ErrNotFound { + return nil, me + } + } + + v := db.s.version() + value, cSched, err := v.get(ikey, ro, false) + v.release() + if cSched { + // Trigger table compaction. 
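+		// (Editor's note: cSched is reported by the version read path when a
+		// table looks worth compacting; the send below is only a nudge to
+		// the compaction goroutine and does not block the read.)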
+ db.compSendTrigger(db.tcompCmdC) + } + return +} + +func (db *DB) has(key []byte, seq uint64, ro *opt.ReadOptions) (ret bool, err error) { + ikey := newIkey(key, seq, ktSeek) + + em, fm := db.getMems() + for _, m := range [...]*memDB{em, fm} { + if m == nil { + continue + } + defer m.decref() + + mk, _, me := m.mdb.Find(ikey) + if me == nil { + ukey, _, kt, kerr := parseIkey(mk) + if kerr != nil { + // Shouldn't have had happen. + panic(kerr) + } + if db.s.icmp.uCompare(ukey, key) == 0 { + if kt == ktDel { + return false, nil + } + return true, nil + } + } else if me != ErrNotFound { + return false, me + } + } + + v := db.s.version() + _, cSched, err := v.get(ikey, ro, true) + v.release() + if cSched { + // Trigger table compaction. + db.compSendTrigger(db.tcompCmdC) + } + if err == nil { + ret = true + } else if err == ErrNotFound { + err = nil + } + return +} + +// Get gets the value for the given key. It returns ErrNotFound if the +// DB does not contains the key. +// +// The returned slice is its own copy, it is safe to modify the contents +// of the returned slice. +// It is safe to modify the contents of the argument after Get returns. +func (db *DB) Get(key []byte, ro *opt.ReadOptions) (value []byte, err error) { + err = db.ok() + if err != nil { + return + } + + se := db.acquireSnapshot() + defer db.releaseSnapshot(se) + return db.get(key, se.seq, ro) +} + +// Has returns true if the DB does contains the given key. +// +// It is safe to modify the contents of the argument after Get returns. +func (db *DB) Has(key []byte, ro *opt.ReadOptions) (ret bool, err error) { + err = db.ok() + if err != nil { + return + } + + se := db.acquireSnapshot() + defer db.releaseSnapshot(se) + return db.has(key, se.seq, ro) +} + +// NewIterator returns an iterator for the latest snapshot of the +// uderlying DB. +// The returned iterator is not goroutine-safe, but it is safe to use +// multiple iterators concurrently, with each in a dedicated goroutine. +// It is also safe to use an iterator concurrently with modifying its +// underlying DB. The resultant key/value pairs are guaranteed to be +// consistent. +// +// Slice allows slicing the iterator to only contains keys in the given +// range. A nil Range.Start is treated as a key before all keys in the +// DB. And a nil Range.Limit is treated as a key after all keys in +// the DB. +// +// The iterator must be released after use, by calling Release method. +// +// Also read Iterator documentation of the leveldb/iterator package. +func (db *DB) NewIterator(slice *util.Range, ro *opt.ReadOptions) iterator.Iterator { + if err := db.ok(); err != nil { + return iterator.NewEmptyIterator(err) + } + + se := db.acquireSnapshot() + defer db.releaseSnapshot(se) + // Iterator holds 'version' lock, 'version' is immutable so snapshot + // can be released after iterator created. + return db.newIterator(se.seq, slice, ro) +} + +// GetSnapshot returns a latest snapshot of the underlying DB. A snapshot +// is a frozen snapshot of a DB state at a particular point in time. The +// content of snapshot are guaranteed to be consistent. +// +// The snapshot must be released after use, by calling Release method. +func (db *DB) GetSnapshot() (*Snapshot, error) { + if err := db.ok(); err != nil { + return nil, err + } + + return db.newSnapshot(), nil +} + +// GetProperty returns value of the given property name. +// +// Property names: +// leveldb.num-files-at-level{n} +// Returns the number of files at level 'n'. 
+// leveldb.stats +// Returns statistics of the underlying DB. +// leveldb.sstables +// Returns sstables list for each level. +// leveldb.blockpool +// Returns block pool stats. +// leveldb.cachedblock +// Returns size of cached block. +// leveldb.openedtables +// Returns number of opened tables. +// leveldb.alivesnaps +// Returns number of alive snapshots. +// leveldb.aliveiters +// Returns number of alive iterators. +func (db *DB) GetProperty(name string) (value string, err error) { + err = db.ok() + if err != nil { + return + } + + const prefix = "leveldb." + if !strings.HasPrefix(name, prefix) { + return "", errors.New("leveldb: GetProperty: unknown property: " + name) + } + p := name[len(prefix):] + + v := db.s.version() + defer v.release() + + numFilesPrefix := "num-files-at-level" + switch { + case strings.HasPrefix(p, numFilesPrefix): + var level uint + var rest string + n, _ := fmt.Sscanf(p[len(numFilesPrefix):], "%d%s", &level, &rest) + if n != 1 || int(level) >= db.s.o.GetNumLevel() { + err = errors.New("leveldb: GetProperty: invalid property: " + name) + } else { + value = fmt.Sprint(v.tLen(int(level))) + } + case p == "stats": + value = "Compactions\n" + + " Level | Tables | Size(MB) | Time(sec) | Read(MB) | Write(MB)\n" + + "-------+------------+---------------+---------------+---------------+---------------\n" + for level, tables := range v.tables { + duration, read, write := db.compStats[level].get() + if len(tables) == 0 && duration == 0 { + continue + } + value += fmt.Sprintf(" %3d | %10d | %13.5f | %13.5f | %13.5f | %13.5f\n", + level, len(tables), float64(tables.size())/1048576.0, duration.Seconds(), + float64(read)/1048576.0, float64(write)/1048576.0) + } + case p == "sstables": + for level, tables := range v.tables { + value += fmt.Sprintf("--- level %d ---\n", level) + for _, t := range tables { + value += fmt.Sprintf("%d:%d[%q .. %q]\n", t.file.Num(), t.size, t.imin, t.imax) + } + } + case p == "blockpool": + value = fmt.Sprintf("%v", db.s.tops.bpool) + case p == "cachedblock": + if db.s.tops.bcache != nil { + value = fmt.Sprintf("%d", db.s.tops.bcache.Size()) + } else { + value = "" + } + case p == "openedtables": + value = fmt.Sprintf("%d", db.s.tops.cache.Size()) + case p == "alivesnaps": + value = fmt.Sprintf("%d", atomic.LoadInt32(&db.aliveSnaps)) + case p == "aliveiters": + value = fmt.Sprintf("%d", atomic.LoadInt32(&db.aliveIters)) + default: + err = errors.New("leveldb: GetProperty: unknown property: " + name) + } + + return +} + +// SizeOf calculates approximate sizes of the given key ranges. +// The length of the returned sizes are equal with the length of the given +// ranges. The returned sizes measure storage space usage, so if the user +// data compresses by a factor of ten, the returned sizes will be one-tenth +// the size of the corresponding user data size. +// The results may not include the sizes of recently written data. +func (db *DB) SizeOf(ranges []util.Range) (Sizes, error) { + if err := db.ok(); err != nil { + return nil, err + } + + v := db.s.version() + defer v.release() + + sizes := make(Sizes, 0, len(ranges)) + for _, r := range ranges { + imin := newIkey(r.Start, kMaxSeq, ktSeek) + imax := newIkey(r.Limit, kMaxSeq, ktSeek) + start, err := v.offsetOf(imin) + if err != nil { + return nil, err + } + limit, err := v.offsetOf(imax) + if err != nil { + return nil, err + } + var size uint64 + if limit >= start { + size = limit - start + } + sizes = append(sizes, size) + } + + return sizes, nil +} + +// Close closes the DB. 
This will also releases any outstanding snapshot and +// abort any in-flight compaction. +// +// It is not safe to close a DB until all outstanding iterators are released. +// It is valid to call Close multiple times. Other methods should not be +// called after the DB has been closed. +func (db *DB) Close() error { + if !db.setClosed() { + return ErrClosed + } + + start := time.Now() + db.log("db@close closing") + + // Clear the finalizer. + runtime.SetFinalizer(db, nil) + + // Get compaction error. + var err error + select { + case err = <-db.compErrC: + default: + } + + // Signal all goroutines. + close(db.closeC) + + // Wait for all gorotines to exit. + db.closeW.Wait() + + // Lock writer and closes journal. + db.writeLockC <- struct{}{} + if db.journal != nil { + db.journal.Close() + db.journalWriter.Close() + } + + if db.writeDelayN > 0 { + db.logf("db@write was delayed N·%d T·%v", db.writeDelayN, db.writeDelay) + } + + // Close session. + db.s.close() + db.logf("db@close done T·%v", time.Since(start)) + db.s.release() + + if db.closer != nil { + if err1 := db.closer.Close(); err == nil { + err = err1 + } + } + + // NIL'ing pointers. + db.s = nil + db.mem = nil + db.frozenMem = nil + db.journal = nil + db.journalWriter = nil + db.journalFile = nil + db.frozenJournalFile = nil + db.closer = nil + + return err +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db_compaction.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db_compaction.go new file mode 100644 index 0000000000..447407abae --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db_compaction.go @@ -0,0 +1,835 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package leveldb + +import ( + "sync" + "time" + + "github.com/syndtr/goleveldb/leveldb/errors" + "github.com/syndtr/goleveldb/leveldb/memdb" + "github.com/syndtr/goleveldb/leveldb/opt" +) + +var ( + errCompactionTransactExiting = errors.New("leveldb: compaction transact exiting") +) + +type cStats struct { + sync.Mutex + duration time.Duration + read uint64 + write uint64 +} + +func (p *cStats) add(n *cStatsStaging) { + p.Lock() + p.duration += n.duration + p.read += n.read + p.write += n.write + p.Unlock() +} + +func (p *cStats) get() (duration time.Duration, read, write uint64) { + p.Lock() + defer p.Unlock() + return p.duration, p.read, p.write +} + +type cStatsStaging struct { + start time.Time + duration time.Duration + on bool + read uint64 + write uint64 +} + +func (p *cStatsStaging) startTimer() { + if !p.on { + p.start = time.Now() + p.on = true + } +} + +func (p *cStatsStaging) stopTimer() { + if p.on { + p.duration += time.Since(p.start) + p.on = false + } +} + +type cMem struct { + s *session + level int + rec *sessionRecord +} + +func newCMem(s *session) *cMem { + return &cMem{s: s, rec: &sessionRecord{numLevel: s.o.GetNumLevel()}} +} + +func (c *cMem) flush(mem *memdb.DB, level int) error { + s := c.s + + // Write memdb to table. + iter := mem.NewIterator(nil) + defer iter.Release() + t, n, err := s.tops.createFrom(iter) + if err != nil { + return err + } + + // Pick level. 
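+ // A negative level means "choose one": pickLevel pushes the new
+ // table to a deeper level when its key range does not overlap
+ // files already there; memCompaction passes -1 for exactly this.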
+ if level < 0 { + v := s.version() + level = v.pickLevel(t.imin.ukey(), t.imax.ukey()) + v.release() + } + c.rec.addTableFile(level, t) + + s.logf("mem@flush created L%d@%d N·%d S·%s %q:%q", level, t.file.Num(), n, shortenb(int(t.size)), t.imin, t.imax) + + c.level = level + return nil +} + +func (c *cMem) reset() { + c.rec = &sessionRecord{numLevel: c.s.o.GetNumLevel()} +} + +func (c *cMem) commit(journal, seq uint64) error { + c.rec.setJournalNum(journal) + c.rec.setSeqNum(seq) + + // Commit changes. + return c.s.commit(c.rec) +} + +func (db *DB) compactionError() { + var ( + err error + wlocked bool + ) +noerr: + // No error. + for { + select { + case err = <-db.compErrSetC: + switch { + case err == nil: + case errors.IsCorrupted(err): + goto hasperr + default: + goto haserr + } + case _, _ = <-db.closeC: + return + } + } +haserr: + // Transient error. + for { + select { + case db.compErrC <- err: + case err = <-db.compErrSetC: + switch { + case err == nil: + goto noerr + case errors.IsCorrupted(err): + goto hasperr + default: + } + case _, _ = <-db.closeC: + return + } + } +hasperr: + // Persistent error. + for { + select { + case db.compErrC <- err: + case db.compPerErrC <- err: + case db.writeLockC <- struct{}{}: + // Hold write lock, so that write won't pass-through. + wlocked = true + case _, _ = <-db.closeC: + if wlocked { + // We should release the lock or Close will hang. + <-db.writeLockC + } + return + } + } +} + +type compactionTransactCounter int + +func (cnt *compactionTransactCounter) incr() { + *cnt++ +} + +type compactionTransactInterface interface { + run(cnt *compactionTransactCounter) error + revert() error +} + +func (db *DB) compactionTransact(name string, t compactionTransactInterface) { + defer func() { + if x := recover(); x != nil { + if x == errCompactionTransactExiting { + if err := t.revert(); err != nil { + db.logf("%s revert error %q", name, err) + } + } + panic(x) + } + }() + + const ( + backoffMin = 1 * time.Second + backoffMax = 8 * time.Second + backoffMul = 2 * time.Second + ) + var ( + backoff = backoffMin + backoffT = time.NewTimer(backoff) + lastCnt = compactionTransactCounter(0) + + disableBackoff = db.s.o.GetDisableCompactionBackoff() + ) + for n := 0; ; n++ { + // Check wether the DB is closed. + if db.isClosed() { + db.logf("%s exiting", name) + db.compactionExitTransact() + } else if n > 0 { + db.logf("%s retrying N·%d", name, n) + } + + // Execute. + cnt := compactionTransactCounter(0) + err := t.run(&cnt) + if err != nil { + db.logf("%s error I·%d %q", name, cnt, err) + } + + // Set compaction error status. + select { + case db.compErrSetC <- err: + case perr := <-db.compPerErrC: + if err != nil { + db.logf("%s exiting (persistent error %q)", name, perr) + db.compactionExitTransact() + } + case _, _ = <-db.closeC: + db.logf("%s exiting", name) + db.compactionExitTransact() + } + if err == nil { + return + } + if errors.IsCorrupted(err) { + db.logf("%s exiting (corruption detected)", name) + db.compactionExitTransact() + } + + if !disableBackoff { + // Reset backoff duration if counter is advancing. + if cnt > lastCnt { + backoff = backoffMin + lastCnt = cnt + } + + // Backoff. 
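+ // Note: backoffMul is itself a time.Duration (two billion
+ // nanoseconds), so the multiply below overshoots and is
+ // immediately clamped to backoffMax; the observed waits are 1s,
+ // then 8s, until an advancing counter resets the backoff.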
+ backoffT.Reset(backoff) + if backoff < backoffMax { + backoff *= backoffMul + if backoff > backoffMax { + backoff = backoffMax + } + } + select { + case <-backoffT.C: + case _, _ = <-db.closeC: + db.logf("%s exiting", name) + db.compactionExitTransact() + } + } + } +} + +type compactionTransactFunc struct { + runFunc func(cnt *compactionTransactCounter) error + revertFunc func() error +} + +func (t *compactionTransactFunc) run(cnt *compactionTransactCounter) error { + return t.runFunc(cnt) +} + +func (t *compactionTransactFunc) revert() error { + if t.revertFunc != nil { + return t.revertFunc() + } + return nil +} + +func (db *DB) compactionTransactFunc(name string, run func(cnt *compactionTransactCounter) error, revert func() error) { + db.compactionTransact(name, &compactionTransactFunc{run, revert}) +} + +func (db *DB) compactionExitTransact() { + panic(errCompactionTransactExiting) +} + +func (db *DB) memCompaction() { + mem := db.getFrozenMem() + if mem == nil { + return + } + defer mem.decref() + + c := newCMem(db.s) + stats := new(cStatsStaging) + + db.logf("mem@flush N·%d S·%s", mem.mdb.Len(), shortenb(mem.mdb.Size())) + + // Don't compact empty memdb. + if mem.mdb.Len() == 0 { + db.logf("mem@flush skipping") + // drop frozen mem + db.dropFrozenMem() + return + } + + // Pause table compaction. + resumeC := make(chan struct{}) + select { + case db.tcompPauseC <- (chan<- struct{})(resumeC): + case <-db.compPerErrC: + close(resumeC) + resumeC = nil + case _, _ = <-db.closeC: + return + } + + db.compactionTransactFunc("mem@flush", func(cnt *compactionTransactCounter) (err error) { + stats.startTimer() + defer stats.stopTimer() + return c.flush(mem.mdb, -1) + }, func() error { + for _, r := range c.rec.addedTables { + db.logf("mem@flush revert @%d", r.num) + f := db.s.getTableFile(r.num) + if err := f.Remove(); err != nil { + return err + } + } + return nil + }) + + db.compactionTransactFunc("mem@commit", func(cnt *compactionTransactCounter) (err error) { + stats.startTimer() + defer stats.stopTimer() + return c.commit(db.journalFile.Num(), db.frozenSeq) + }, nil) + + db.logf("mem@flush committed F·%d T·%v", len(c.rec.addedTables), stats.duration) + + for _, r := range c.rec.addedTables { + stats.write += r.size + } + db.compStats[c.level].add(stats) + + // Drop frozen mem. + db.dropFrozenMem() + + // Resume table compaction. + if resumeC != nil { + select { + case <-resumeC: + close(resumeC) + case _, _ = <-db.closeC: + return + } + } + + // Trigger table compaction. + db.compSendTrigger(db.tcompCmdC) +} + +type tableCompactionBuilder struct { + db *DB + s *session + c *compaction + rec *sessionRecord + stat0, stat1 *cStatsStaging + + snapHasLastUkey bool + snapLastUkey []byte + snapLastSeq uint64 + snapIter int + snapKerrCnt int + snapDropCnt int + + kerrCnt int + dropCnt int + + minSeq uint64 + strict bool + tableSize int + + tw *tWriter +} + +func (b *tableCompactionBuilder) appendKV(key, value []byte) error { + // Create new table if not already. + if b.tw == nil { + // Check for pause event. + if b.db != nil { + select { + case ch := <-b.db.tcompPauseC: + b.db.pauseCompaction(ch) + case _, _ = <-b.db.closeC: + b.db.compactionExitTransact() + default: + } + } + + // Create new table. + var err error + b.tw, err = b.s.tops.create() + if err != nil { + return err + } + } + + // Write key/value into table. 
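+ // (The pause check above runs only when a new table is about to
+ // be created, so compaction yields between output tables, never
+ // mid-table.)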
+ return b.tw.append(key, value) +} + +func (b *tableCompactionBuilder) needFlush() bool { + return b.tw.tw.BytesLen() >= b.tableSize +} + +func (b *tableCompactionBuilder) flush() error { + t, err := b.tw.finish() + if err != nil { + return err + } + b.rec.addTableFile(b.c.level+1, t) + b.stat1.write += t.size + b.s.logf("table@build created L%d@%d N·%d S·%s %q:%q", b.c.level+1, t.file.Num(), b.tw.tw.EntriesLen(), shortenb(int(t.size)), t.imin, t.imax) + b.tw = nil + return nil +} + +func (b *tableCompactionBuilder) cleanup() { + if b.tw != nil { + b.tw.drop() + b.tw = nil + } +} + +func (b *tableCompactionBuilder) run(cnt *compactionTransactCounter) error { + snapResumed := b.snapIter > 0 + hasLastUkey := b.snapHasLastUkey // The key might has zero length, so this is necessary. + lastUkey := append([]byte{}, b.snapLastUkey...) + lastSeq := b.snapLastSeq + b.kerrCnt = b.snapKerrCnt + b.dropCnt = b.snapDropCnt + // Restore compaction state. + b.c.restore() + + defer b.cleanup() + + b.stat1.startTimer() + defer b.stat1.stopTimer() + + iter := b.c.newIterator() + defer iter.Release() + for i := 0; iter.Next(); i++ { + // Incr transact counter. + cnt.incr() + + // Skip until last state. + if i < b.snapIter { + continue + } + + resumed := false + if snapResumed { + resumed = true + snapResumed = false + } + + ikey := iter.Key() + ukey, seq, kt, kerr := parseIkey(ikey) + + if kerr == nil { + shouldStop := !resumed && b.c.shouldStopBefore(ikey) + + if !hasLastUkey || b.s.icmp.uCompare(lastUkey, ukey) != 0 { + // First occurrence of this user key. + + // Only rotate tables if ukey doesn't hop across. + if b.tw != nil && (shouldStop || b.needFlush()) { + if err := b.flush(); err != nil { + return err + } + + // Creates snapshot of the state. + b.c.save() + b.snapHasLastUkey = hasLastUkey + b.snapLastUkey = append(b.snapLastUkey[:0], lastUkey...) + b.snapLastSeq = lastSeq + b.snapIter = i + b.snapKerrCnt = b.kerrCnt + b.snapDropCnt = b.dropCnt + } + + hasLastUkey = true + lastUkey = append(lastUkey[:0], ukey...) + lastSeq = kMaxSeq + } + + switch { + case lastSeq <= b.minSeq: + // Dropped because newer entry for same user key exist + fallthrough // (A) + case kt == ktDel && seq <= b.minSeq && b.c.baseLevelForKey(lastUkey): + // For this user key: + // (1) there is no data in higher levels + // (2) data in lower levels will have larger seq numbers + // (3) data in layers that are being compacted here and have + // smaller seq numbers will be dropped in the next + // few iterations of this loop (by rule (A) above). + // Therefore this deletion marker is obsolete and can be dropped. + lastSeq = seq + b.dropCnt++ + continue + default: + lastSeq = seq + } + } else { + if b.strict { + return kerr + } + + // Don't drop corrupted keys. + hasLastUkey = false + lastUkey = lastUkey[:0] + lastSeq = kMaxSeq + b.kerrCnt++ + } + + if err := b.appendKV(ikey, iter.Value()); err != nil { + return err + } + } + + if err := iter.Error(); err != nil { + return err + } + + // Finish last table. 
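+ // (On a transact retry, run resumes from the b.snap* fields saved
+ // at the last flush, so only the unflushed tail is re-processed.)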
+ if b.tw != nil && !b.tw.empty() { + return b.flush() + } + return nil +} + +func (b *tableCompactionBuilder) revert() error { + for _, at := range b.rec.addedTables { + b.s.logf("table@build revert @%d", at.num) + f := b.s.getTableFile(at.num) + if err := f.Remove(); err != nil { + return err + } + } + return nil +} + +func (db *DB) tableCompaction(c *compaction, noTrivial bool) { + defer c.release() + + rec := &sessionRecord{numLevel: db.s.o.GetNumLevel()} + rec.addCompPtr(c.level, c.imax) + + if !noTrivial && c.trivial() { + t := c.tables[0][0] + db.logf("table@move L%d@%d -> L%d", c.level, t.file.Num(), c.level+1) + rec.delTable(c.level, t.file.Num()) + rec.addTableFile(c.level+1, t) + db.compactionTransactFunc("table@move", func(cnt *compactionTransactCounter) (err error) { + return db.s.commit(rec) + }, nil) + return + } + + var stats [2]cStatsStaging + for i, tables := range c.tables { + for _, t := range tables { + stats[i].read += t.size + // Insert deleted tables into record + rec.delTable(c.level+i, t.file.Num()) + } + } + sourceSize := int(stats[0].read + stats[1].read) + minSeq := db.minSeq() + db.logf("table@compaction L%d·%d -> L%d·%d S·%s Q·%d", c.level, len(c.tables[0]), c.level+1, len(c.tables[1]), shortenb(sourceSize), minSeq) + + b := &tableCompactionBuilder{ + db: db, + s: db.s, + c: c, + rec: rec, + stat1: &stats[1], + minSeq: minSeq, + strict: db.s.o.GetStrict(opt.StrictCompaction), + tableSize: db.s.o.GetCompactionTableSize(c.level + 1), + } + db.compactionTransact("table@build", b) + + // Commit changes + db.compactionTransactFunc("table@commit", func(cnt *compactionTransactCounter) (err error) { + stats[1].startTimer() + defer stats[1].stopTimer() + return db.s.commit(rec) + }, nil) + + resultSize := int(stats[1].write) + db.logf("table@compaction committed F%s S%s Ke·%d D·%d T·%v", sint(len(rec.addedTables)-len(rec.deletedTables)), sshortenb(resultSize-sourceSize), b.kerrCnt, b.dropCnt, stats[1].duration) + + // Save compaction stats + for i := range stats { + db.compStats[c.level+1].add(&stats[i]) + } +} + +func (db *DB) tableRangeCompaction(level int, umin, umax []byte) { + db.logf("table@compaction range L%d %q:%q", level, umin, umax) + + if level >= 0 { + if c := db.s.getCompactionRange(level, umin, umax); c != nil { + db.tableCompaction(c, true) + } + } else { + v := db.s.version() + m := 1 + for i, t := range v.tables[1:] { + if t.overlaps(db.s.icmp, umin, umax, false) { + m = i + 1 + } + } + v.release() + + for level := 0; level < m; level++ { + if c := db.s.getCompactionRange(level, umin, umax); c != nil { + db.tableCompaction(c, true) + } + } + } +} + +func (db *DB) tableAutoCompaction() { + if c := db.s.pickCompaction(); c != nil { + db.tableCompaction(c, false) + } +} + +func (db *DB) tableNeedCompaction() bool { + v := db.s.version() + defer v.release() + return v.needCompaction() +} + +func (db *DB) pauseCompaction(ch chan<- struct{}) { + select { + case ch <- struct{}{}: + case _, _ = <-db.closeC: + db.compactionExitTransact() + } +} + +type cCmd interface { + ack(err error) +} + +type cIdle struct { + ackC chan<- error +} + +func (r cIdle) ack(err error) { + if r.ackC != nil { + defer func() { + recover() + }() + r.ackC <- err + } +} + +type cRange struct { + level int + min, max []byte + ackC chan<- error +} + +func (r cRange) ack(err error) { + if r.ackC != nil { + defer func() { + recover() + }() + r.ackC <- err + } +} + +// This will trigger auto compation and/or wait for all compaction to be done. 
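+// The protocol is: send cIdle{ackC} on compC, then wait on ackC; both
+// steps also select on compErrC and closeC, so a failing or closing DB
+// cannot leave the caller blocked.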
+func (db *DB) compSendIdle(compC chan<- cCmd) (err error) { + ch := make(chan error) + defer close(ch) + // Send cmd. + select { + case compC <- cIdle{ch}: + case err = <-db.compErrC: + return + case _, _ = <-db.closeC: + return ErrClosed + } + // Wait cmd. + select { + case err = <-ch: + case err = <-db.compErrC: + case _, _ = <-db.closeC: + return ErrClosed + } + return err +} + +// This will trigger auto compaction but will not wait for it. +func (db *DB) compSendTrigger(compC chan<- cCmd) { + select { + case compC <- cIdle{}: + default: + } +} + +// Send range compaction request. +func (db *DB) compSendRange(compC chan<- cCmd, level int, min, max []byte) (err error) { + ch := make(chan error) + defer close(ch) + // Send cmd. + select { + case compC <- cRange{level, min, max, ch}: + case err := <-db.compErrC: + return err + case _, _ = <-db.closeC: + return ErrClosed + } + // Wait cmd. + select { + case err = <-ch: + case err = <-db.compErrC: + case _, _ = <-db.closeC: + return ErrClosed + } + return err +} + +func (db *DB) mCompaction() { + var x cCmd + + defer func() { + if x := recover(); x != nil { + if x != errCompactionTransactExiting { + panic(x) + } + } + if x != nil { + x.ack(ErrClosed) + } + db.closeW.Done() + }() + + for { + select { + case x = <-db.mcompCmdC: + switch x.(type) { + case cIdle: + db.memCompaction() + x.ack(nil) + x = nil + default: + panic("leveldb: unknown command") + } + case _, _ = <-db.closeC: + return + } + } +} + +func (db *DB) tCompaction() { + var x cCmd + var ackQ []cCmd + + defer func() { + if x := recover(); x != nil { + if x != errCompactionTransactExiting { + panic(x) + } + } + for i := range ackQ { + ackQ[i].ack(ErrClosed) + ackQ[i] = nil + } + if x != nil { + x.ack(ErrClosed) + } + db.closeW.Done() + }() + + for { + if db.tableNeedCompaction() { + select { + case x = <-db.tcompCmdC: + case ch := <-db.tcompPauseC: + db.pauseCompaction(ch) + continue + case _, _ = <-db.closeC: + return + default: + } + } else { + for i := range ackQ { + ackQ[i].ack(nil) + ackQ[i] = nil + } + ackQ = ackQ[:0] + select { + case x = <-db.tcompCmdC: + case ch := <-db.tcompPauseC: + db.pauseCompaction(ch) + continue + case _, _ = <-db.closeC: + return + } + } + if x != nil { + switch cmd := x.(type) { + case cIdle: + ackQ = append(ackQ, x) + case cRange: + db.tableRangeCompaction(cmd.level, cmd.min, cmd.max) + x.ack(nil) + default: + panic("leveldb: unknown command") + } + x = nil + } + db.tableAutoCompaction() + } +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db_iter.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db_iter.go new file mode 100644 index 0000000000..4607e5daf3 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db_iter.go @@ -0,0 +1,332 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+ +package leveldb + +import ( + "errors" + "runtime" + "sync" + "sync/atomic" + + "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/opt" + "github.com/syndtr/goleveldb/leveldb/util" +) + +var ( + errInvalidIkey = errors.New("leveldb: Iterator: invalid internal key") +) + +type memdbReleaser struct { + once sync.Once + m *memDB +} + +func (mr *memdbReleaser) Release() { + mr.once.Do(func() { + mr.m.decref() + }) +} + +func (db *DB) newRawIterator(slice *util.Range, ro *opt.ReadOptions) iterator.Iterator { + em, fm := db.getMems() + v := db.s.version() + + ti := v.getIterators(slice, ro) + n := len(ti) + 2 + i := make([]iterator.Iterator, 0, n) + emi := em.mdb.NewIterator(slice) + emi.SetReleaser(&memdbReleaser{m: em}) + i = append(i, emi) + if fm != nil { + fmi := fm.mdb.NewIterator(slice) + fmi.SetReleaser(&memdbReleaser{m: fm}) + i = append(i, fmi) + } + i = append(i, ti...) + strict := opt.GetStrict(db.s.o.Options, ro, opt.StrictReader) + mi := iterator.NewMergedIterator(i, db.s.icmp, strict) + mi.SetReleaser(&versionReleaser{v: v}) + return mi +} + +func (db *DB) newIterator(seq uint64, slice *util.Range, ro *opt.ReadOptions) *dbIter { + var islice *util.Range + if slice != nil { + islice = &util.Range{} + if slice.Start != nil { + islice.Start = newIkey(slice.Start, kMaxSeq, ktSeek) + } + if slice.Limit != nil { + islice.Limit = newIkey(slice.Limit, kMaxSeq, ktSeek) + } + } + rawIter := db.newRawIterator(islice, ro) + iter := &dbIter{ + db: db, + icmp: db.s.icmp, + iter: rawIter, + seq: seq, + strict: opt.GetStrict(db.s.o.Options, ro, opt.StrictReader), + key: make([]byte, 0), + value: make([]byte, 0), + } + atomic.AddInt32(&db.aliveIters, 1) + runtime.SetFinalizer(iter, (*dbIter).Release) + return iter +} + +type dir int + +const ( + dirReleased dir = iota - 1 + dirSOI + dirEOI + dirBackward + dirForward +) + +// dbIter represent an interator states over a database session. +type dbIter struct { + db *DB + icmp *iComparer + iter iterator.Iterator + seq uint64 + strict bool + + dir dir + key []byte + value []byte + err error + releaser util.Releaser +} + +func (i *dbIter) setErr(err error) { + i.err = err + i.key = nil + i.value = nil +} + +func (i *dbIter) iterErr() { + if err := i.iter.Error(); err != nil { + i.setErr(err) + } +} + +func (i *dbIter) Valid() bool { + return i.err == nil && i.dir > dirEOI +} + +func (i *dbIter) First() bool { + if i.err != nil { + return false + } else if i.dir == dirReleased { + i.err = ErrIterReleased + return false + } + + if i.iter.First() { + i.dir = dirSOI + return i.next() + } + i.dir = dirEOI + i.iterErr() + return false +} + +func (i *dbIter) Last() bool { + if i.err != nil { + return false + } else if i.dir == dirReleased { + i.err = ErrIterReleased + return false + } + + if i.iter.Last() { + return i.prev() + } + i.dir = dirSOI + i.iterErr() + return false +} + +func (i *dbIter) Seek(key []byte) bool { + if i.err != nil { + return false + } else if i.dir == dirReleased { + i.err = ErrIterReleased + return false + } + + ikey := newIkey(key, i.seq, ktSeek) + if i.iter.Seek(ikey) { + i.dir = dirSOI + return i.next() + } + i.dir = dirEOI + i.iterErr() + return false +} + +func (i *dbIter) next() bool { + for { + if ukey, seq, kt, kerr := parseIkey(i.iter.Key()); kerr == nil { + if seq <= i.seq { + switch kt { + case ktDel: + // Skip deleted key. + i.key = append(i.key[:0], ukey...) 
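+ // Remember the deleted user key so that older entries for
+ // the same key fail the uCompare test below and are skipped.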
+ i.dir = dirForward + case ktVal: + if i.dir == dirSOI || i.icmp.uCompare(ukey, i.key) > 0 { + i.key = append(i.key[:0], ukey...) + i.value = append(i.value[:0], i.iter.Value()...) + i.dir = dirForward + return true + } + } + } + } else if i.strict { + i.setErr(kerr) + break + } + if !i.iter.Next() { + i.dir = dirEOI + i.iterErr() + break + } + } + return false +} + +func (i *dbIter) Next() bool { + if i.dir == dirEOI || i.err != nil { + return false + } else if i.dir == dirReleased { + i.err = ErrIterReleased + return false + } + + if !i.iter.Next() || (i.dir == dirBackward && !i.iter.Next()) { + i.dir = dirEOI + i.iterErr() + return false + } + return i.next() +} + +func (i *dbIter) prev() bool { + i.dir = dirBackward + del := true + if i.iter.Valid() { + for { + if ukey, seq, kt, kerr := parseIkey(i.iter.Key()); kerr == nil { + if seq <= i.seq { + if !del && i.icmp.uCompare(ukey, i.key) < 0 { + return true + } + del = (kt == ktDel) + if !del { + i.key = append(i.key[:0], ukey...) + i.value = append(i.value[:0], i.iter.Value()...) + } + } + } else if i.strict { + i.setErr(kerr) + return false + } + if !i.iter.Prev() { + break + } + } + } + if del { + i.dir = dirSOI + i.iterErr() + return false + } + return true +} + +func (i *dbIter) Prev() bool { + if i.dir == dirSOI || i.err != nil { + return false + } else if i.dir == dirReleased { + i.err = ErrIterReleased + return false + } + + switch i.dir { + case dirEOI: + return i.Last() + case dirForward: + for i.iter.Prev() { + if ukey, _, _, kerr := parseIkey(i.iter.Key()); kerr == nil { + if i.icmp.uCompare(ukey, i.key) < 0 { + goto cont + } + } else if i.strict { + i.setErr(kerr) + return false + } + } + i.dir = dirSOI + i.iterErr() + return false + } + +cont: + return i.prev() +} + +func (i *dbIter) Key() []byte { + if i.err != nil || i.dir <= dirEOI { + return nil + } + return i.key +} + +func (i *dbIter) Value() []byte { + if i.err != nil || i.dir <= dirEOI { + return nil + } + return i.value +} + +func (i *dbIter) Release() { + if i.dir != dirReleased { + // Clear the finalizer. + runtime.SetFinalizer(i, nil) + + if i.releaser != nil { + i.releaser.Release() + i.releaser = nil + } + + i.dir = dirReleased + i.key = nil + i.value = nil + i.iter.Release() + i.iter = nil + atomic.AddInt32(&i.db.aliveIters, -1) + i.db = nil + } +} + +func (i *dbIter) SetReleaser(releaser util.Releaser) { + if i.dir == dirReleased { + panic(util.ErrReleased) + } + if i.releaser != nil && releaser != nil { + panic(util.ErrHasReleaser) + } + i.releaser = releaser +} + +func (i *dbIter) Error() error { + return i.err +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db_snapshot.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db_snapshot.go new file mode 100644 index 0000000000..0372848ff1 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db_snapshot.go @@ -0,0 +1,183 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package leveldb + +import ( + "container/list" + "fmt" + "runtime" + "sync" + "sync/atomic" + + "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/opt" + "github.com/syndtr/goleveldb/leveldb/util" +) + +type snapshotElement struct { + seq uint64 + ref int + e *list.Element +} + +// Acquires a snapshot, based on latest sequence. 
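+// Snapshots taken at the same sequence number share one ref-counted
+// element on db.snapsList; the list order is what lets minSeq (below)
+// report the oldest sequence still visible to a live snapshot.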
+func (db *DB) acquireSnapshot() *snapshotElement {
+ db.snapsMu.Lock()
+ defer db.snapsMu.Unlock()
+
+ seq := db.getSeq()
+
+ if e := db.snapsList.Back(); e != nil {
+ se := e.Value.(*snapshotElement)
+ if se.seq == seq {
+ se.ref++
+ return se
+ } else if seq < se.seq {
+ panic("leveldb: sequence number is not increasing")
+ }
+ }
+ se := &snapshotElement{seq: seq, ref: 1}
+ se.e = db.snapsList.PushBack(se)
+ return se
+}
+
+// Releases the given snapshot element.
+func (db *DB) releaseSnapshot(se *snapshotElement) {
+ db.snapsMu.Lock()
+ defer db.snapsMu.Unlock()
+
+ se.ref--
+ if se.ref == 0 {
+ db.snapsList.Remove(se.e)
+ se.e = nil
+ } else if se.ref < 0 {
+ panic("leveldb: Snapshot: negative element reference")
+ }
+}
+
+// Gets the minimum sequence number that is not being snapshotted.
+func (db *DB) minSeq() uint64 {
+ db.snapsMu.Lock()
+ defer db.snapsMu.Unlock()
+
+ if e := db.snapsList.Front(); e != nil {
+ return e.Value.(*snapshotElement).seq
+ }
+
+ return db.getSeq()
+}
+
+// Snapshot is a DB snapshot.
+type Snapshot struct {
+ db *DB
+ elem *snapshotElement
+ mu sync.RWMutex
+ released bool
+}
+
+// Creates a new snapshot object.
+func (db *DB) newSnapshot() *Snapshot {
+ snap := &Snapshot{
+ db: db,
+ elem: db.acquireSnapshot(),
+ }
+ atomic.AddInt32(&db.aliveSnaps, 1)
+ runtime.SetFinalizer(snap, (*Snapshot).Release)
+ return snap
+}
+
+func (snap *Snapshot) String() string {
+ return fmt.Sprintf("leveldb.Snapshot{%d}", snap.elem.seq)
+}
+
+// Get gets the value for the given key. It returns ErrNotFound if
+// the DB does not contain the key.
+//
+// The caller should not modify the contents of the returned slice, but
+// it is safe to modify the contents of the argument after Get returns.
+func (snap *Snapshot) Get(key []byte, ro *opt.ReadOptions) (value []byte, err error) {
+ err = snap.db.ok()
+ if err != nil {
+ return
+ }
+ snap.mu.RLock()
+ defer snap.mu.RUnlock()
+ if snap.released {
+ err = ErrSnapshotReleased
+ return
+ }
+ return snap.db.get(key, snap.elem.seq, ro)
+}
+
+// Has returns true if the DB contains the given key.
+//
+// It is safe to modify the contents of the argument after Has returns.
+func (snap *Snapshot) Has(key []byte, ro *opt.ReadOptions) (ret bool, err error) {
+ err = snap.db.ok()
+ if err != nil {
+ return
+ }
+ snap.mu.RLock()
+ defer snap.mu.RUnlock()
+ if snap.released {
+ err = ErrSnapshotReleased
+ return
+ }
+ return snap.db.has(key, snap.elem.seq, ro)
+}
+
+// NewIterator returns an iterator for the snapshot of the underlying DB.
+// The returned iterator is not goroutine-safe, but it is safe to use
+// multiple iterators concurrently, with each in a dedicated goroutine.
+// It is also safe to use an iterator concurrently with modifying its
+// underlying DB. The resultant key/value pairs are guaranteed to be
+// consistent.
+//
+// Slice allows slicing the iterator to only contain keys in the given
+// range. A nil Range.Start is treated as a key before all keys in the
+// DB. And a nil Range.Limit is treated as a key after all keys in
+// the DB.
+//
+// The iterator must be released after use, by calling the Release method.
+// Releasing the snapshot doesn't mean releasing the iterator too; the
+// iterator would still be valid until released.
+//
+// Also read the Iterator documentation of the leveldb/iterator package.
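+//
+// A minimal usage sketch, assuming an open *DB named db (hypothetical
+// key range; error handling elided):
+//
+//	snap, _ := db.GetSnapshot()
+//	defer snap.Release()
+//	iter := snap.NewIterator(&util.Range{Start: []byte("a"), Limit: []byte("n")}, nil)
+//	defer iter.Release()
+//	for iter.Next() {
+//		fmt.Printf("%s=%s\n", iter.Key(), iter.Value())
+//	}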
+func (snap *Snapshot) NewIterator(slice *util.Range, ro *opt.ReadOptions) iterator.Iterator { + if err := snap.db.ok(); err != nil { + return iterator.NewEmptyIterator(err) + } + snap.mu.Lock() + defer snap.mu.Unlock() + if snap.released { + return iterator.NewEmptyIterator(ErrSnapshotReleased) + } + // Since iterator already hold version ref, it doesn't need to + // hold snapshot ref. + return snap.db.newIterator(snap.elem.seq, slice, ro) +} + +// Release releases the snapshot. This will not release any returned +// iterators, the iterators would still be valid until released or the +// underlying DB is closed. +// +// Other methods should not be called after the snapshot has been released. +func (snap *Snapshot) Release() { + snap.mu.Lock() + defer snap.mu.Unlock() + + if !snap.released { + // Clear the finalizer. + runtime.SetFinalizer(snap, nil) + + snap.released = true + snap.db.releaseSnapshot(snap.elem) + atomic.AddInt32(&snap.db.aliveSnaps, -1) + snap.db = nil + snap.elem = nil + } +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db_state.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db_state.go new file mode 100644 index 0000000000..24ecab504d --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db_state.go @@ -0,0 +1,202 @@ +// Copyright (c) 2013, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package leveldb + +import ( + "sync/atomic" + "time" + + "github.com/syndtr/goleveldb/leveldb/journal" + "github.com/syndtr/goleveldb/leveldb/memdb" +) + +type memDB struct { + db *DB + mdb *memdb.DB + ref int32 +} + +func (m *memDB) incref() { + atomic.AddInt32(&m.ref, 1) +} + +func (m *memDB) decref() { + if ref := atomic.AddInt32(&m.ref, -1); ref == 0 { + // Only put back memdb with std capacity. + if m.mdb.Capacity() == m.db.s.o.GetWriteBuffer() { + m.mdb.Reset() + m.db.mpoolPut(m.mdb) + } + m.db = nil + m.mdb = nil + } else if ref < 0 { + panic("negative memdb ref") + } +} + +// Get latest sequence number. +func (db *DB) getSeq() uint64 { + return atomic.LoadUint64(&db.seq) +} + +// Atomically adds delta to seq. +func (db *DB) addSeq(delta uint64) { + atomic.AddUint64(&db.seq, delta) +} + +func (db *DB) mpoolPut(mem *memdb.DB) { + defer func() { + recover() + }() + select { + case db.memPool <- mem: + default: + } +} + +func (db *DB) mpoolGet() *memdb.DB { + select { + case mem := <-db.memPool: + return mem + default: + return nil + } +} + +func (db *DB) mpoolDrain() { + ticker := time.NewTicker(30 * time.Second) + for { + select { + case <-ticker.C: + select { + case <-db.memPool: + default: + } + case _, _ = <-db.closeC: + close(db.memPool) + return + } + } +} + +// Create new memdb and froze the old one; need external synchronization. +// newMem only called synchronously by the writer. 
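+// Rotation resets the journal writer onto a fresh file, parks the old
+// file as frozenJournalFile, and moves the current memdb to frozenMem
+// for the mem-compaction goroutine to flush.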
+func (db *DB) newMem(n int) (mem *memDB, err error) { + num := db.s.allocFileNum() + file := db.s.getJournalFile(num) + w, err := file.Create() + if err != nil { + db.s.reuseFileNum(num) + return + } + + db.memMu.Lock() + defer db.memMu.Unlock() + + if db.frozenMem != nil { + panic("still has frozen mem") + } + + if db.journal == nil { + db.journal = journal.NewWriter(w) + } else { + db.journal.Reset(w) + db.journalWriter.Close() + db.frozenJournalFile = db.journalFile + } + db.journalWriter = w + db.journalFile = file + db.frozenMem = db.mem + mdb := db.mpoolGet() + if mdb == nil || mdb.Capacity() < n { + mdb = memdb.New(db.s.icmp, maxInt(db.s.o.GetWriteBuffer(), n)) + } + mem = &memDB{ + db: db, + mdb: mdb, + ref: 2, + } + db.mem = mem + // The seq only incremented by the writer. And whoever called newMem + // should hold write lock, so no need additional synchronization here. + db.frozenSeq = db.seq + return +} + +// Get all memdbs. +func (db *DB) getMems() (e, f *memDB) { + db.memMu.RLock() + defer db.memMu.RUnlock() + if db.mem == nil { + panic("nil effective mem") + } + db.mem.incref() + if db.frozenMem != nil { + db.frozenMem.incref() + } + return db.mem, db.frozenMem +} + +// Get frozen memdb. +func (db *DB) getEffectiveMem() *memDB { + db.memMu.RLock() + defer db.memMu.RUnlock() + if db.mem == nil { + panic("nil effective mem") + } + db.mem.incref() + return db.mem +} + +// Check whether we has frozen memdb. +func (db *DB) hasFrozenMem() bool { + db.memMu.RLock() + defer db.memMu.RUnlock() + return db.frozenMem != nil +} + +// Get frozen memdb. +func (db *DB) getFrozenMem() *memDB { + db.memMu.RLock() + defer db.memMu.RUnlock() + if db.frozenMem != nil { + db.frozenMem.incref() + } + return db.frozenMem +} + +// Drop frozen memdb; assume that frozen memdb isn't nil. +func (db *DB) dropFrozenMem() { + db.memMu.Lock() + if err := db.frozenJournalFile.Remove(); err != nil { + db.logf("journal@remove removing @%d %q", db.frozenJournalFile.Num(), err) + } else { + db.logf("journal@remove removed @%d", db.frozenJournalFile.Num()) + } + db.frozenJournalFile = nil + db.frozenMem.decref() + db.frozenMem = nil + db.memMu.Unlock() +} + +// Set closed flag; return true if not already closed. +func (db *DB) setClosed() bool { + return atomic.CompareAndSwapUint32(&db.closed, 0, 1) +} + +// Check whether DB was closed. +func (db *DB) isClosed() bool { + return atomic.LoadUint32(&db.closed) != 0 +} + +// Check read ok status. +func (db *DB) ok() error { + if db.isClosed() { + return ErrClosed + } + return nil +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db_test.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db_test.go new file mode 100644 index 0000000000..0acb567a16 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db_test.go @@ -0,0 +1,2579 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+ +package leveldb + +import ( + "bytes" + "container/list" + crand "crypto/rand" + "encoding/binary" + "fmt" + "math/rand" + "os" + "path/filepath" + "runtime" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + "unsafe" + + "github.com/syndtr/goleveldb/leveldb/comparer" + "github.com/syndtr/goleveldb/leveldb/errors" + "github.com/syndtr/goleveldb/leveldb/filter" + "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/opt" + "github.com/syndtr/goleveldb/leveldb/storage" + "github.com/syndtr/goleveldb/leveldb/util" +) + +func tkey(i int) []byte { + return []byte(fmt.Sprintf("%016d", i)) +} + +func tval(seed, n int) []byte { + r := rand.New(rand.NewSource(int64(seed))) + return randomString(r, n) +} + +type dbHarness struct { + t *testing.T + + stor *testStorage + db *DB + o *opt.Options + ro *opt.ReadOptions + wo *opt.WriteOptions +} + +func newDbHarnessWopt(t *testing.T, o *opt.Options) *dbHarness { + h := new(dbHarness) + h.init(t, o) + return h +} + +func newDbHarness(t *testing.T) *dbHarness { + return newDbHarnessWopt(t, &opt.Options{}) +} + +func (h *dbHarness) init(t *testing.T, o *opt.Options) { + h.t = t + h.stor = newTestStorage(t) + h.o = o + h.ro = nil + h.wo = nil + + if err := h.openDB0(); err != nil { + // So that it will come after fatal message. + defer h.stor.Close() + h.t.Fatal("Open (init): got error: ", err) + } +} + +func (h *dbHarness) openDB0() (err error) { + h.t.Log("opening DB") + h.db, err = Open(h.stor, h.o) + return +} + +func (h *dbHarness) openDB() { + if err := h.openDB0(); err != nil { + h.t.Fatal("Open: got error: ", err) + } +} + +func (h *dbHarness) closeDB0() error { + h.t.Log("closing DB") + return h.db.Close() +} + +func (h *dbHarness) closeDB() { + if err := h.closeDB0(); err != nil { + h.t.Error("Close: got error: ", err) + } + h.stor.CloseCheck() + runtime.GC() +} + +func (h *dbHarness) reopenDB() { + h.closeDB() + h.openDB() +} + +func (h *dbHarness) close() { + h.closeDB0() + h.db = nil + h.stor.Close() + h.stor = nil + runtime.GC() +} + +func (h *dbHarness) openAssert(want bool) { + db, err := Open(h.stor, h.o) + if err != nil { + if want { + h.t.Error("Open: assert: got error: ", err) + } else { + h.t.Log("Open: assert: got error (expected): ", err) + } + } else { + if !want { + h.t.Error("Open: assert: expect error") + } + db.Close() + } +} + +func (h *dbHarness) write(batch *Batch) { + if err := h.db.Write(batch, h.wo); err != nil { + h.t.Error("Write: got error: ", err) + } +} + +func (h *dbHarness) put(key, value string) { + if err := h.db.Put([]byte(key), []byte(value), h.wo); err != nil { + h.t.Error("Put: got error: ", err) + } +} + +func (h *dbHarness) putMulti(n int, low, hi string) { + for i := 0; i < n; i++ { + h.put(low, "begin") + h.put(hi, "end") + h.compactMem() + } +} + +func (h *dbHarness) maxNextLevelOverlappingBytes(want uint64) { + t := h.t + db := h.db + + var ( + maxOverlaps uint64 + maxLevel int + ) + v := db.s.version() + for i, tt := range v.tables[1 : len(v.tables)-1] { + level := i + 1 + next := v.tables[level+1] + for _, t := range tt { + r := next.getOverlaps(nil, db.s.icmp, t.imin.ukey(), t.imax.ukey(), false) + sum := r.size() + if sum > maxOverlaps { + maxOverlaps = sum + maxLevel = level + } + } + } + v.release() + + if maxOverlaps > want { + t.Errorf("next level most overlapping bytes is more than %d, got=%d level=%d", want, maxOverlaps, maxLevel) + } else { + t.Logf("next level most overlapping bytes is %d, level=%d want=%d", maxOverlaps, maxLevel, want) + } +} + +func 
(h *dbHarness) delete(key string) { + t := h.t + db := h.db + + err := db.Delete([]byte(key), h.wo) + if err != nil { + t.Error("Delete: got error: ", err) + } +} + +func (h *dbHarness) assertNumKeys(want int) { + iter := h.db.NewIterator(nil, h.ro) + defer iter.Release() + got := 0 + for iter.Next() { + got++ + } + if err := iter.Error(); err != nil { + h.t.Error("assertNumKeys: ", err) + } + if want != got { + h.t.Errorf("assertNumKeys: want=%d got=%d", want, got) + } +} + +func (h *dbHarness) getr(db Reader, key string, expectFound bool) (found bool, v []byte) { + t := h.t + v, err := db.Get([]byte(key), h.ro) + switch err { + case ErrNotFound: + if expectFound { + t.Errorf("Get: key '%s' not found, want found", key) + } + case nil: + found = true + if !expectFound { + t.Errorf("Get: key '%s' found, want not found", key) + } + default: + t.Error("Get: got error: ", err) + } + return +} + +func (h *dbHarness) get(key string, expectFound bool) (found bool, v []byte) { + return h.getr(h.db, key, expectFound) +} + +func (h *dbHarness) getValr(db Reader, key, value string) { + t := h.t + found, r := h.getr(db, key, true) + if !found { + return + } + rval := string(r) + if rval != value { + t.Errorf("Get: invalid value, got '%s', want '%s'", rval, value) + } +} + +func (h *dbHarness) getVal(key, value string) { + h.getValr(h.db, key, value) +} + +func (h *dbHarness) allEntriesFor(key, want string) { + t := h.t + db := h.db + s := db.s + + ikey := newIkey([]byte(key), kMaxSeq, ktVal) + iter := db.newRawIterator(nil, nil) + if !iter.Seek(ikey) && iter.Error() != nil { + t.Error("AllEntries: error during seek, err: ", iter.Error()) + return + } + res := "[ " + first := true + for iter.Valid() { + if ukey, _, kt, kerr := parseIkey(iter.Key()); kerr == nil { + if s.icmp.uCompare(ikey.ukey(), ukey) != 0 { + break + } + if !first { + res += ", " + } + first = false + switch kt { + case ktVal: + res += string(iter.Value()) + case ktDel: + res += "DEL" + } + } else { + if !first { + res += ", " + } + first = false + res += "CORRUPTED" + } + iter.Next() + } + if !first { + res += " " + } + res += "]" + if res != want { + t.Errorf("AllEntries: assert failed for key %q, got=%q want=%q", key, res, want) + } +} + +// Return a string that contains all key,value pairs in order, +// formatted like "(k1->v1)(k2->v2)". 
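+// The dump is taken through a snapshot, so writes racing with the
+// assertion cannot change what is observed.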
+func (h *dbHarness) getKeyVal(want string) { + t := h.t + db := h.db + + s, err := db.GetSnapshot() + if err != nil { + t.Fatal("GetSnapshot: got error: ", err) + } + res := "" + iter := s.NewIterator(nil, nil) + for iter.Next() { + res += fmt.Sprintf("(%s->%s)", string(iter.Key()), string(iter.Value())) + } + iter.Release() + + if res != want { + t.Errorf("GetKeyVal: invalid key/value pair, got=%q want=%q", res, want) + } + s.Release() +} + +func (h *dbHarness) waitCompaction() { + t := h.t + db := h.db + if err := db.compSendIdle(db.tcompCmdC); err != nil { + t.Error("compaction error: ", err) + } +} + +func (h *dbHarness) waitMemCompaction() { + t := h.t + db := h.db + + if err := db.compSendIdle(db.mcompCmdC); err != nil { + t.Error("compaction error: ", err) + } +} + +func (h *dbHarness) compactMem() { + t := h.t + db := h.db + + t.Log("starting memdb compaction") + + db.writeLockC <- struct{}{} + defer func() { + <-db.writeLockC + }() + + if _, err := db.rotateMem(0); err != nil { + t.Error("compaction error: ", err) + } + if err := db.compSendIdle(db.mcompCmdC); err != nil { + t.Error("compaction error: ", err) + } + + if h.totalTables() == 0 { + t.Error("zero tables after mem compaction") + } + + t.Log("memdb compaction done") +} + +func (h *dbHarness) compactRangeAtErr(level int, min, max string, wanterr bool) { + t := h.t + db := h.db + + var _min, _max []byte + if min != "" { + _min = []byte(min) + } + if max != "" { + _max = []byte(max) + } + + t.Logf("starting table range compaction: level=%d, min=%q, max=%q", level, min, max) + + if err := db.compSendRange(db.tcompCmdC, level, _min, _max); err != nil { + if wanterr { + t.Log("CompactRangeAt: got error (expected): ", err) + } else { + t.Error("CompactRangeAt: got error: ", err) + } + } else if wanterr { + t.Error("CompactRangeAt: expect error") + } + + t.Log("table range compaction done") +} + +func (h *dbHarness) compactRangeAt(level int, min, max string) { + h.compactRangeAtErr(level, min, max, false) +} + +func (h *dbHarness) compactRange(min, max string) { + t := h.t + db := h.db + + t.Logf("starting DB range compaction: min=%q, max=%q", min, max) + + var r util.Range + if min != "" { + r.Start = []byte(min) + } + if max != "" { + r.Limit = []byte(max) + } + if err := db.CompactRange(r); err != nil { + t.Error("CompactRange: got error: ", err) + } + + t.Log("DB range compaction done") +} + +func (h *dbHarness) sizeAssert(start, limit string, low, hi uint64) { + t := h.t + db := h.db + + s, err := db.SizeOf([]util.Range{ + {[]byte(start), []byte(limit)}, + }) + if err != nil { + t.Error("SizeOf: got error: ", err) + } + if s.Sum() < low || s.Sum() > hi { + t.Errorf("sizeof %q to %q not in range, want %d - %d, got %d", + shorten(start), shorten(limit), low, hi, s.Sum()) + } +} + +func (h *dbHarness) getSnapshot() (s *Snapshot) { + s, err := h.db.GetSnapshot() + if err != nil { + h.t.Fatal("GetSnapshot: got error: ", err) + } + return +} +func (h *dbHarness) tablesPerLevel(want string) { + res := "" + nz := 0 + v := h.db.s.version() + for level, tt := range v.tables { + if level > 0 { + res += "," + } + res += fmt.Sprint(len(tt)) + if len(tt) > 0 { + nz = len(res) + } + } + v.release() + res = res[:nz] + if res != want { + h.t.Errorf("invalid tables len, want=%s, got=%s", want, res) + } +} + +func (h *dbHarness) totalTables() (n int) { + v := h.db.s.version() + for _, tt := range v.tables { + n += len(tt) + } + v.release() + return +} + +type keyValue interface { + Key() []byte + Value() []byte +} + +func testKeyVal(t 
*testing.T, kv keyValue, want string) { + res := string(kv.Key()) + "->" + string(kv.Value()) + if res != want { + t.Errorf("invalid key/value, want=%q, got=%q", want, res) + } +} + +func numKey(num int) string { + return fmt.Sprintf("key%06d", num) +} + +var _bloom_filter = filter.NewBloomFilter(10) + +func truno(t *testing.T, o *opt.Options, f func(h *dbHarness)) { + for i := 0; i < 4; i++ { + func() { + switch i { + case 0: + case 1: + if o == nil { + o = &opt.Options{Filter: _bloom_filter} + } else { + old := o + o = &opt.Options{} + *o = *old + o.Filter = _bloom_filter + } + case 2: + if o == nil { + o = &opt.Options{Compression: opt.NoCompression} + } else { + old := o + o = &opt.Options{} + *o = *old + o.Compression = opt.NoCompression + } + } + h := newDbHarnessWopt(t, o) + defer h.close() + switch i { + case 3: + h.reopenDB() + } + f(h) + }() + } +} + +func trun(t *testing.T, f func(h *dbHarness)) { + truno(t, nil, f) +} + +func testAligned(t *testing.T, name string, offset uintptr) { + if offset%8 != 0 { + t.Errorf("field %s offset is not 64-bit aligned", name) + } +} + +func Test_FieldsAligned(t *testing.T) { + p1 := new(DB) + testAligned(t, "DB.seq", unsafe.Offsetof(p1.seq)) + p2 := new(session) + testAligned(t, "session.stNextFileNum", unsafe.Offsetof(p2.stNextFileNum)) + testAligned(t, "session.stJournalNum", unsafe.Offsetof(p2.stJournalNum)) + testAligned(t, "session.stPrevJournalNum", unsafe.Offsetof(p2.stPrevJournalNum)) + testAligned(t, "session.stSeqNum", unsafe.Offsetof(p2.stSeqNum)) +} + +func TestDB_Locking(t *testing.T) { + h := newDbHarness(t) + defer h.stor.Close() + h.openAssert(false) + h.closeDB() + h.openAssert(true) +} + +func TestDB_Empty(t *testing.T) { + trun(t, func(h *dbHarness) { + h.get("foo", false) + + h.reopenDB() + h.get("foo", false) + }) +} + +func TestDB_ReadWrite(t *testing.T) { + trun(t, func(h *dbHarness) { + h.put("foo", "v1") + h.getVal("foo", "v1") + h.put("bar", "v2") + h.put("foo", "v3") + h.getVal("foo", "v3") + h.getVal("bar", "v2") + + h.reopenDB() + h.getVal("foo", "v3") + h.getVal("bar", "v2") + }) +} + +func TestDB_PutDeleteGet(t *testing.T) { + trun(t, func(h *dbHarness) { + h.put("foo", "v1") + h.getVal("foo", "v1") + h.put("foo", "v2") + h.getVal("foo", "v2") + h.delete("foo") + h.get("foo", false) + + h.reopenDB() + h.get("foo", false) + }) +} + +func TestDB_EmptyBatch(t *testing.T) { + h := newDbHarness(t) + defer h.close() + + h.get("foo", false) + err := h.db.Write(new(Batch), h.wo) + if err != nil { + t.Error("writing empty batch yield error: ", err) + } + h.get("foo", false) +} + +func TestDB_GetFromFrozen(t *testing.T) { + h := newDbHarnessWopt(t, &opt.Options{WriteBuffer: 100100}) + defer h.close() + + h.put("foo", "v1") + h.getVal("foo", "v1") + + h.stor.DelaySync(storage.TypeTable) // Block sync calls + h.put("k1", strings.Repeat("x", 100000)) // Fill memtable + h.put("k2", strings.Repeat("y", 100000)) // Trigger compaction + for i := 0; h.db.getFrozenMem() == nil && i < 100; i++ { + time.Sleep(10 * time.Microsecond) + } + if h.db.getFrozenMem() == nil { + h.stor.ReleaseSync(storage.TypeTable) + t.Fatal("No frozen mem") + } + h.getVal("foo", "v1") + h.stor.ReleaseSync(storage.TypeTable) // Release sync calls + + h.reopenDB() + h.getVal("foo", "v1") + h.get("k1", true) + h.get("k2", true) +} + +func TestDB_GetFromTable(t *testing.T) { + trun(t, func(h *dbHarness) { + h.put("foo", "v1") + h.compactMem() + h.getVal("foo", "v1") + }) +} + +func TestDB_GetSnapshot(t *testing.T) { + trun(t, func(h *dbHarness) { + bar := 
strings.Repeat("b", 200) + h.put("foo", "v1") + h.put(bar, "v1") + + snap, err := h.db.GetSnapshot() + if err != nil { + t.Fatal("GetSnapshot: got error: ", err) + } + + h.put("foo", "v2") + h.put(bar, "v2") + + h.getVal("foo", "v2") + h.getVal(bar, "v2") + h.getValr(snap, "foo", "v1") + h.getValr(snap, bar, "v1") + + h.compactMem() + + h.getVal("foo", "v2") + h.getVal(bar, "v2") + h.getValr(snap, "foo", "v1") + h.getValr(snap, bar, "v1") + + snap.Release() + + h.reopenDB() + h.getVal("foo", "v2") + h.getVal(bar, "v2") + }) +} + +func TestDB_GetLevel0Ordering(t *testing.T) { + trun(t, func(h *dbHarness) { + for i := 0; i < 4; i++ { + h.put("bar", fmt.Sprintf("b%d", i)) + h.put("foo", fmt.Sprintf("v%d", i)) + h.compactMem() + } + h.getVal("foo", "v3") + h.getVal("bar", "b3") + + v := h.db.s.version() + t0len := v.tLen(0) + v.release() + if t0len < 2 { + t.Errorf("level-0 tables is less than 2, got %d", t0len) + } + + h.reopenDB() + h.getVal("foo", "v3") + h.getVal("bar", "b3") + }) +} + +func TestDB_GetOrderedByLevels(t *testing.T) { + trun(t, func(h *dbHarness) { + h.put("foo", "v1") + h.compactMem() + h.compactRange("a", "z") + h.getVal("foo", "v1") + h.put("foo", "v2") + h.compactMem() + h.getVal("foo", "v2") + }) +} + +func TestDB_GetPicksCorrectFile(t *testing.T) { + trun(t, func(h *dbHarness) { + // Arrange to have multiple files in a non-level-0 level. + h.put("a", "va") + h.compactMem() + h.compactRange("a", "b") + h.put("x", "vx") + h.compactMem() + h.compactRange("x", "y") + h.put("f", "vf") + h.compactMem() + h.compactRange("f", "g") + + h.getVal("a", "va") + h.getVal("f", "vf") + h.getVal("x", "vx") + + h.compactRange("", "") + h.getVal("a", "va") + h.getVal("f", "vf") + h.getVal("x", "vx") + }) +} + +func TestDB_GetEncountersEmptyLevel(t *testing.T) { + trun(t, func(h *dbHarness) { + // Arrange for the following to happen: + // * sstable A in level 0 + // * nothing in level 1 + // * sstable B in level 2 + // Then do enough Get() calls to arrange for an automatic compaction + // of sstable A. A bug would cause the compaction to be marked as + // occuring at level 1 (instead of the correct level 0). + + // Step 1: First place sstables in levels 0 and 2 + for i := 0; ; i++ { + if i >= 100 { + t.Fatal("could not fill levels-0 and level-2") + } + v := h.db.s.version() + if v.tLen(0) > 0 && v.tLen(2) > 0 { + v.release() + break + } + v.release() + h.put("a", "begin") + h.put("z", "end") + h.compactMem() + + h.getVal("a", "begin") + h.getVal("z", "end") + } + + // Step 2: clear level 1 if necessary. 
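+ // (Range-compacting level 1 pushes any level-1 table down to
+ // level 2, producing the "1,0,1" shape asserted below.)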
+ h.compactRangeAt(1, "", "") + h.tablesPerLevel("1,0,1") + + h.getVal("a", "begin") + h.getVal("z", "end") + + // Step 3: read a bunch of times + for i := 0; i < 200; i++ { + h.get("missing", false) + } + + // Step 4: Wait for compaction to finish + h.waitCompaction() + + v := h.db.s.version() + if v.tLen(0) > 0 { + t.Errorf("level-0 tables more than 0, got %d", v.tLen(0)) + } + v.release() + + h.getVal("a", "begin") + h.getVal("z", "end") + }) +} + +func TestDB_IterMultiWithDelete(t *testing.T) { + trun(t, func(h *dbHarness) { + h.put("a", "va") + h.put("b", "vb") + h.put("c", "vc") + h.delete("b") + h.get("b", false) + + iter := h.db.NewIterator(nil, nil) + iter.Seek([]byte("c")) + testKeyVal(t, iter, "c->vc") + iter.Prev() + testKeyVal(t, iter, "a->va") + iter.Release() + + h.compactMem() + + iter = h.db.NewIterator(nil, nil) + iter.Seek([]byte("c")) + testKeyVal(t, iter, "c->vc") + iter.Prev() + testKeyVal(t, iter, "a->va") + iter.Release() + }) +} + +func TestDB_IteratorPinsRef(t *testing.T) { + h := newDbHarness(t) + defer h.close() + + h.put("foo", "hello") + + // Get iterator that will yield the current contents of the DB. + iter := h.db.NewIterator(nil, nil) + + // Write to force compactions + h.put("foo", "newvalue1") + for i := 0; i < 100; i++ { + h.put(numKey(i), strings.Repeat(fmt.Sprintf("v%09d", i), 100000/10)) + } + h.put("foo", "newvalue2") + + iter.First() + testKeyVal(t, iter, "foo->hello") + if iter.Next() { + t.Errorf("expect eof") + } + iter.Release() +} + +func TestDB_Recover(t *testing.T) { + trun(t, func(h *dbHarness) { + h.put("foo", "v1") + h.put("baz", "v5") + + h.reopenDB() + h.getVal("foo", "v1") + + h.getVal("foo", "v1") + h.getVal("baz", "v5") + h.put("bar", "v2") + h.put("foo", "v3") + + h.reopenDB() + h.getVal("foo", "v3") + h.put("foo", "v4") + h.getVal("foo", "v4") + h.getVal("bar", "v2") + h.getVal("baz", "v5") + }) +} + +func TestDB_RecoverWithEmptyJournal(t *testing.T) { + trun(t, func(h *dbHarness) { + h.put("foo", "v1") + h.put("foo", "v2") + + h.reopenDB() + h.reopenDB() + h.put("foo", "v3") + + h.reopenDB() + h.getVal("foo", "v3") + }) +} + +func TestDB_RecoverDuringMemtableCompaction(t *testing.T) { + truno(t, &opt.Options{WriteBuffer: 1000000}, func(h *dbHarness) { + + h.stor.DelaySync(storage.TypeTable) + h.put("big1", strings.Repeat("x", 10000000)) + h.put("big2", strings.Repeat("y", 1000)) + h.put("bar", "v2") + h.stor.ReleaseSync(storage.TypeTable) + + h.reopenDB() + h.getVal("bar", "v2") + h.getVal("big1", strings.Repeat("x", 10000000)) + h.getVal("big2", strings.Repeat("y", 1000)) + }) +} + +func TestDB_MinorCompactionsHappen(t *testing.T) { + h := newDbHarnessWopt(t, &opt.Options{WriteBuffer: 10000}) + defer h.close() + + n := 500 + + key := func(i int) string { + return fmt.Sprintf("key%06d", i) + } + + for i := 0; i < n; i++ { + h.put(key(i), key(i)+strings.Repeat("v", 1000)) + } + + for i := 0; i < n; i++ { + h.getVal(key(i), key(i)+strings.Repeat("v", 1000)) + } + + h.reopenDB() + for i := 0; i < n; i++ { + h.getVal(key(i), key(i)+strings.Repeat("v", 1000)) + } +} + +func TestDB_RecoverWithLargeJournal(t *testing.T) { + h := newDbHarness(t) + defer h.close() + + h.put("big1", strings.Repeat("1", 200000)) + h.put("big2", strings.Repeat("2", 200000)) + h.put("small3", strings.Repeat("3", 10)) + h.put("small4", strings.Repeat("4", 10)) + h.tablesPerLevel("") + + // Make sure that if we re-open with a small write buffer size that + // we flush table files in the middle of a large journal file. 
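+ // (Each big value is ~200KB against a 100KB write buffer, so the
+ // recovery replay loop must flush memdb to level-0 tables more than
+ // once; see the mem.Size() >= writeBuffer check in recoverJournal.)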
+ h.o.WriteBuffer = 100000 + h.reopenDB() + h.getVal("big1", strings.Repeat("1", 200000)) + h.getVal("big2", strings.Repeat("2", 200000)) + h.getVal("small3", strings.Repeat("3", 10)) + h.getVal("small4", strings.Repeat("4", 10)) + v := h.db.s.version() + if v.tLen(0) <= 1 { + t.Errorf("tables-0 less than one") + } + v.release() +} + +func TestDB_CompactionsGenerateMultipleFiles(t *testing.T) { + h := newDbHarnessWopt(t, &opt.Options{ + WriteBuffer: 10000000, + Compression: opt.NoCompression, + }) + defer h.close() + + v := h.db.s.version() + if v.tLen(0) > 0 { + t.Errorf("level-0 tables more than 0, got %d", v.tLen(0)) + } + v.release() + + n := 80 + + // Write 8MB (80 values, each 100K) + for i := 0; i < n; i++ { + h.put(numKey(i), strings.Repeat(fmt.Sprintf("v%09d", i), 100000/10)) + } + + // Reopening moves updates to level-0 + h.reopenDB() + h.compactRangeAt(0, "", "") + + v = h.db.s.version() + if v.tLen(0) > 0 { + t.Errorf("level-0 tables more than 0, got %d", v.tLen(0)) + } + if v.tLen(1) <= 1 { + t.Errorf("level-1 tables less than 1, got %d", v.tLen(1)) + } + v.release() + + for i := 0; i < n; i++ { + h.getVal(numKey(i), strings.Repeat(fmt.Sprintf("v%09d", i), 100000/10)) + } +} + +func TestDB_RepeatedWritesToSameKey(t *testing.T) { + h := newDbHarnessWopt(t, &opt.Options{WriteBuffer: 100000}) + defer h.close() + + maxTables := h.o.GetNumLevel() + h.o.GetWriteL0PauseTrigger() + + value := strings.Repeat("v", 2*h.o.GetWriteBuffer()) + for i := 0; i < 5*maxTables; i++ { + h.put("key", value) + n := h.totalTables() + if n > maxTables { + t.Errorf("total tables exceed %d, got=%d, iter=%d", maxTables, n, i) + } + } +} + +func TestDB_RepeatedWritesToSameKeyAfterReopen(t *testing.T) { + h := newDbHarnessWopt(t, &opt.Options{WriteBuffer: 100000}) + defer h.close() + + h.reopenDB() + + maxTables := h.o.GetNumLevel() + h.o.GetWriteL0PauseTrigger() + + value := strings.Repeat("v", 2*h.o.GetWriteBuffer()) + for i := 0; i < 5*maxTables; i++ { + h.put("key", value) + n := h.totalTables() + if n > maxTables { + t.Errorf("total tables exceed %d, got=%d, iter=%d", maxTables, n, i) + } + } +} + +func TestDB_SparseMerge(t *testing.T) { + h := newDbHarnessWopt(t, &opt.Options{Compression: opt.NoCompression}) + defer h.close() + + h.putMulti(h.o.GetNumLevel(), "A", "Z") + + // Suppose there is: + // small amount of data with prefix A + // large amount of data with prefix B + // small amount of data with prefix C + // and that recent updates have made small changes to all three prefixes. + // Check that we do not do a compaction that merges all of B in one shot. 
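+	// The bound used below is maxNextLevelOverlappingBytes(20MiB): after
+	// each compaction step, no single table may overlap more than 20MiB
+	// of the next level, so the bulk of the B-prefixed data is never
+	// rewritten in one pass.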
+ h.put("A", "va") + value := strings.Repeat("x", 1000) + for i := 0; i < 100000; i++ { + h.put(fmt.Sprintf("B%010d", i), value) + } + h.put("C", "vc") + h.compactMem() + h.compactRangeAt(0, "", "") + h.waitCompaction() + + // Make sparse update + h.put("A", "va2") + h.put("B100", "bvalue2") + h.put("C", "vc2") + h.compactMem() + + h.waitCompaction() + h.maxNextLevelOverlappingBytes(20 * 1048576) + h.compactRangeAt(0, "", "") + h.waitCompaction() + h.maxNextLevelOverlappingBytes(20 * 1048576) + h.compactRangeAt(1, "", "") + h.waitCompaction() + h.maxNextLevelOverlappingBytes(20 * 1048576) +} + +func TestDB_SizeOf(t *testing.T) { + h := newDbHarnessWopt(t, &opt.Options{ + Compression: opt.NoCompression, + WriteBuffer: 10000000, + }) + defer h.close() + + h.sizeAssert("", "xyz", 0, 0) + h.reopenDB() + h.sizeAssert("", "xyz", 0, 0) + + // Write 8MB (80 values, each 100K) + n := 80 + s1 := 100000 + s2 := 105000 + + for i := 0; i < n; i++ { + h.put(numKey(i), strings.Repeat(fmt.Sprintf("v%09d", i), s1/10)) + } + + // 0 because SizeOf() does not account for memtable space + h.sizeAssert("", numKey(50), 0, 0) + + for r := 0; r < 3; r++ { + h.reopenDB() + + for cs := 0; cs < n; cs += 10 { + for i := 0; i < n; i += 10 { + h.sizeAssert("", numKey(i), uint64(s1*i), uint64(s2*i)) + h.sizeAssert("", numKey(i)+".suffix", uint64(s1*(i+1)), uint64(s2*(i+1))) + h.sizeAssert(numKey(i), numKey(i+10), uint64(s1*10), uint64(s2*10)) + } + + h.sizeAssert("", numKey(50), uint64(s1*50), uint64(s2*50)) + h.sizeAssert("", numKey(50)+".suffix", uint64(s1*50), uint64(s2*50)) + + h.compactRangeAt(0, numKey(cs), numKey(cs+9)) + } + + v := h.db.s.version() + if v.tLen(0) != 0 { + t.Errorf("level-0 tables was not zero, got %d", v.tLen(0)) + } + if v.tLen(1) == 0 { + t.Error("level-1 tables was zero") + } + v.release() + } +} + +func TestDB_SizeOf_MixOfSmallAndLarge(t *testing.T) { + h := newDbHarnessWopt(t, &opt.Options{Compression: opt.NoCompression}) + defer h.close() + + sizes := []uint64{ + 10000, + 10000, + 100000, + 10000, + 100000, + 10000, + 300000, + 10000, + } + + for i, n := range sizes { + h.put(numKey(i), strings.Repeat(fmt.Sprintf("v%09d", i), int(n)/10)) + } + + for r := 0; r < 3; r++ { + h.reopenDB() + + var x uint64 + for i, n := range sizes { + y := x + if i > 0 { + y += 1000 + } + h.sizeAssert("", numKey(i), x, y) + x += n + } + + h.sizeAssert(numKey(3), numKey(5), 110000, 111000) + + h.compactRangeAt(0, "", "") + } +} + +func TestDB_Snapshot(t *testing.T) { + trun(t, func(h *dbHarness) { + h.put("foo", "v1") + s1 := h.getSnapshot() + h.put("foo", "v2") + s2 := h.getSnapshot() + h.put("foo", "v3") + s3 := h.getSnapshot() + h.put("foo", "v4") + + h.getValr(s1, "foo", "v1") + h.getValr(s2, "foo", "v2") + h.getValr(s3, "foo", "v3") + h.getVal("foo", "v4") + + s3.Release() + h.getValr(s1, "foo", "v1") + h.getValr(s2, "foo", "v2") + h.getVal("foo", "v4") + + s1.Release() + h.getValr(s2, "foo", "v2") + h.getVal("foo", "v4") + + s2.Release() + h.getVal("foo", "v4") + }) +} + +func TestDB_SnapshotList(t *testing.T) { + db := &DB{snapsList: list.New()} + e0a := db.acquireSnapshot() + e0b := db.acquireSnapshot() + db.seq = 1 + e1 := db.acquireSnapshot() + db.seq = 2 + e2 := db.acquireSnapshot() + + if db.minSeq() != 0 { + t.Fatalf("invalid sequence number, got=%d", db.minSeq()) + } + db.releaseSnapshot(e0a) + if db.minSeq() != 0 { + t.Fatalf("invalid sequence number, got=%d", db.minSeq()) + } + db.releaseSnapshot(e2) + if db.minSeq() != 0 { + t.Fatalf("invalid sequence number, got=%d", db.minSeq()) + } + 
db.releaseSnapshot(e0b) + if db.minSeq() != 1 { + t.Fatalf("invalid sequence number, got=%d", db.minSeq()) + } + e2 = db.acquireSnapshot() + if db.minSeq() != 1 { + t.Fatalf("invalid sequence number, got=%d", db.minSeq()) + } + db.releaseSnapshot(e1) + if db.minSeq() != 2 { + t.Fatalf("invalid sequence number, got=%d", db.minSeq()) + } + db.releaseSnapshot(e2) + if db.minSeq() != 2 { + t.Fatalf("invalid sequence number, got=%d", db.minSeq()) + } +} + +func TestDB_HiddenValuesAreRemoved(t *testing.T) { + trun(t, func(h *dbHarness) { + s := h.db.s + + h.put("foo", "v1") + h.compactMem() + m := h.o.GetMaxMemCompationLevel() + v := s.version() + num := v.tLen(m) + v.release() + if num != 1 { + t.Errorf("invalid level-%d len, want=1 got=%d", m, num) + } + + // Place a table at level last-1 to prevent merging with preceding mutation + h.put("a", "begin") + h.put("z", "end") + h.compactMem() + v = s.version() + if v.tLen(m) != 1 { + t.Errorf("invalid level-%d len, want=1 got=%d", m, v.tLen(m)) + } + if v.tLen(m-1) != 1 { + t.Errorf("invalid level-%d len, want=1 got=%d", m-1, v.tLen(m-1)) + } + v.release() + + h.delete("foo") + h.put("foo", "v2") + h.allEntriesFor("foo", "[ v2, DEL, v1 ]") + h.compactMem() + h.allEntriesFor("foo", "[ v2, DEL, v1 ]") + h.compactRangeAt(m-2, "", "z") + // DEL eliminated, but v1 remains because we aren't compacting that level + // (DEL can be eliminated because v2 hides v1). + h.allEntriesFor("foo", "[ v2, v1 ]") + h.compactRangeAt(m-1, "", "") + // Merging last-1 w/ last, so we are the base level for "foo", so + // DEL is removed. (as is v1). + h.allEntriesFor("foo", "[ v2 ]") + }) +} + +func TestDB_DeletionMarkers2(t *testing.T) { + h := newDbHarness(t) + defer h.close() + s := h.db.s + + h.put("foo", "v1") + h.compactMem() + m := h.o.GetMaxMemCompationLevel() + v := s.version() + num := v.tLen(m) + v.release() + if num != 1 { + t.Errorf("invalid level-%d len, want=1 got=%d", m, num) + } + + // Place a table at level last-1 to prevent merging with preceding mutation + h.put("a", "begin") + h.put("z", "end") + h.compactMem() + v = s.version() + if v.tLen(m) != 1 { + t.Errorf("invalid level-%d len, want=1 got=%d", m, v.tLen(m)) + } + if v.tLen(m-1) != 1 { + t.Errorf("invalid level-%d len, want=1 got=%d", m-1, v.tLen(m-1)) + } + v.release() + + h.delete("foo") + h.allEntriesFor("foo", "[ DEL, v1 ]") + h.compactMem() // Moves to level last-2 + h.allEntriesFor("foo", "[ DEL, v1 ]") + h.compactRangeAt(m-2, "", "") + // DEL kept: "last" file overlaps + h.allEntriesFor("foo", "[ DEL, v1 ]") + h.compactRangeAt(m-1, "", "") + // Merging last-1 w/ last, so we are the base level for "foo", so + // DEL is removed. (as is v1). 
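+	// A deletion marker can only be dropped once the compaction output
+	// lands in the lowest level that could still hold an older version
+	// of the key; here that is the last level, so nothing remains.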
+	h.allEntriesFor("foo", "[ ]")
+}
+
+func TestDB_CompactionTableOpenError(t *testing.T) {
+	h := newDbHarnessWopt(t, &opt.Options{OpenFilesCacheCapacity: -1})
+	defer h.close()
+
+	im := 10
+	jm := 10
+	for r := 0; r < 2; r++ {
+		for i := 0; i < im; i++ {
+			for j := 0; j < jm; j++ {
+				h.put(fmt.Sprintf("k%d,%d", i, j), fmt.Sprintf("v%d,%d", i, j))
+			}
+			h.compactMem()
+		}
+	}
+
+	if n := h.totalTables(); n != im*2 {
+		t.Errorf("total tables is %d, want %d", n, im*2)
+	}
+
+	h.stor.SetEmuErr(storage.TypeTable, tsOpOpen)
+	go h.db.CompactRange(util.Range{})
+	if err := h.db.compSendIdle(h.db.tcompCmdC); err != nil {
+		t.Log("compaction error: ", err)
+	}
+	h.closeDB0()
+	h.openDB()
+	h.stor.SetEmuErr(0, tsOpOpen)
+
+	for i := 0; i < im; i++ {
+		for j := 0; j < jm; j++ {
+			h.getVal(fmt.Sprintf("k%d,%d", i, j), fmt.Sprintf("v%d,%d", i, j))
+		}
+	}
+}
+
+func TestDB_OverlapInLevel0(t *testing.T) {
+	trun(t, func(h *dbHarness) {
+		if h.o.GetMaxMemCompationLevel() != 2 {
+			t.Fatal("fix test to reflect the config")
+		}
+
+		// Fill levels 1 and 2 to disable the pushing of new memtables to levels > 0.
+		h.put("100", "v100")
+		h.put("999", "v999")
+		h.compactMem()
+		h.delete("100")
+		h.delete("999")
+		h.compactMem()
+		h.tablesPerLevel("0,1,1")
+
+		// Make files spanning the following ranges in level-0:
+		// files[0] 200 .. 900
+		// files[1] 300 .. 500
+		// Note that files are sorted by min key.
+		h.put("300", "v300")
+		h.put("500", "v500")
+		h.compactMem()
+		h.put("200", "v200")
+		h.put("600", "v600")
+		h.put("900", "v900")
+		h.compactMem()
+		h.tablesPerLevel("2,1,1")
+
+		// Compact away the placeholder files we created initially
+		h.compactRangeAt(1, "", "")
+		h.compactRangeAt(2, "", "")
+		h.tablesPerLevel("2")
+
+		// Do a memtable compaction. Before bug-fix, the compaction would
+		// not detect the overlap with level-0 files and would incorrectly place
+		// the deletion in a deeper level.
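+		// If the tombstone were placed deeper than the level-0 table
+		// holding "600", the stale value would shadow it and the final
+		// get() below would wrongly succeed.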
+ h.delete("600") + h.compactMem() + h.tablesPerLevel("3") + h.get("600", false) + }) +} + +func TestDB_L0_CompactionBug_Issue44_a(t *testing.T) { + h := newDbHarness(t) + defer h.close() + + h.reopenDB() + h.put("b", "v") + h.reopenDB() + h.delete("b") + h.delete("a") + h.reopenDB() + h.delete("a") + h.reopenDB() + h.put("a", "v") + h.reopenDB() + h.reopenDB() + h.getKeyVal("(a->v)") + h.waitCompaction() + h.getKeyVal("(a->v)") +} + +func TestDB_L0_CompactionBug_Issue44_b(t *testing.T) { + h := newDbHarness(t) + defer h.close() + + h.reopenDB() + h.put("", "") + h.reopenDB() + h.delete("e") + h.put("", "") + h.reopenDB() + h.put("c", "cv") + h.reopenDB() + h.put("", "") + h.reopenDB() + h.put("", "") + h.waitCompaction() + h.reopenDB() + h.put("d", "dv") + h.reopenDB() + h.put("", "") + h.reopenDB() + h.delete("d") + h.delete("b") + h.reopenDB() + h.getKeyVal("(->)(c->cv)") + h.waitCompaction() + h.getKeyVal("(->)(c->cv)") +} + +func TestDB_SingleEntryMemCompaction(t *testing.T) { + trun(t, func(h *dbHarness) { + for i := 0; i < 10; i++ { + h.put("big", strings.Repeat("v", opt.DefaultWriteBuffer)) + h.compactMem() + h.put("key", strings.Repeat("v", opt.DefaultBlockSize)) + h.compactMem() + h.put("k", "v") + h.compactMem() + h.put("", "") + h.compactMem() + h.put("verybig", strings.Repeat("v", opt.DefaultWriteBuffer*2)) + h.compactMem() + } + }) +} + +func TestDB_ManifestWriteError(t *testing.T) { + for i := 0; i < 2; i++ { + func() { + h := newDbHarness(t) + defer h.close() + + h.put("foo", "bar") + h.getVal("foo", "bar") + + // Mem compaction (will succeed) + h.compactMem() + h.getVal("foo", "bar") + v := h.db.s.version() + if n := v.tLen(h.o.GetMaxMemCompationLevel()); n != 1 { + t.Errorf("invalid total tables, want=1 got=%d", n) + } + v.release() + + if i == 0 { + h.stor.SetEmuErr(storage.TypeManifest, tsOpWrite) + } else { + h.stor.SetEmuErr(storage.TypeManifest, tsOpSync) + } + + // Merging compaction (will fail) + h.compactRangeAtErr(h.o.GetMaxMemCompationLevel(), "", "", true) + + h.db.Close() + h.stor.SetEmuErr(0, tsOpWrite) + h.stor.SetEmuErr(0, tsOpSync) + + // Should not lose data + h.openDB() + h.getVal("foo", "bar") + }() + } +} + +func assertErr(t *testing.T, err error, wanterr bool) { + if err != nil { + if wanterr { + t.Log("AssertErr: got error (expected): ", err) + } else { + t.Error("AssertErr: got error: ", err) + } + } else if wanterr { + t.Error("AssertErr: expect error") + } +} + +func TestDB_ClosedIsClosed(t *testing.T) { + h := newDbHarness(t) + db := h.db + + var iter, iter2 iterator.Iterator + var snap *Snapshot + func() { + defer h.close() + + h.put("k", "v") + h.getVal("k", "v") + + iter = db.NewIterator(nil, h.ro) + iter.Seek([]byte("k")) + testKeyVal(t, iter, "k->v") + + var err error + snap, err = db.GetSnapshot() + if err != nil { + t.Fatal("GetSnapshot: got error: ", err) + } + + h.getValr(snap, "k", "v") + + iter2 = snap.NewIterator(nil, h.ro) + iter2.Seek([]byte("k")) + testKeyVal(t, iter2, "k->v") + + h.put("foo", "v2") + h.delete("foo") + + // closing DB + iter.Release() + iter2.Release() + }() + + assertErr(t, db.Put([]byte("x"), []byte("y"), h.wo), true) + _, err := db.Get([]byte("k"), h.ro) + assertErr(t, err, true) + + if iter.Valid() { + t.Errorf("iter.Valid should false") + } + assertErr(t, iter.Error(), false) + testKeyVal(t, iter, "->") + if iter.Seek([]byte("k")) { + t.Errorf("iter.Seek should false") + } + assertErr(t, iter.Error(), true) + + assertErr(t, iter2.Error(), false) + + _, err = snap.Get([]byte("k"), h.ro) + assertErr(t, err, 
true) + + _, err = db.GetSnapshot() + assertErr(t, err, true) + + iter3 := db.NewIterator(nil, h.ro) + assertErr(t, iter3.Error(), true) + + iter3 = snap.NewIterator(nil, h.ro) + assertErr(t, iter3.Error(), true) + + assertErr(t, db.Delete([]byte("k"), h.wo), true) + + _, err = db.GetProperty("leveldb.stats") + assertErr(t, err, true) + + _, err = db.SizeOf([]util.Range{{[]byte("a"), []byte("z")}}) + assertErr(t, err, true) + + assertErr(t, db.CompactRange(util.Range{}), true) + + assertErr(t, db.Close(), true) +} + +type numberComparer struct{} + +func (numberComparer) num(x []byte) (n int) { + fmt.Sscan(string(x[1:len(x)-1]), &n) + return +} + +func (numberComparer) Name() string { + return "test.NumberComparer" +} + +func (p numberComparer) Compare(a, b []byte) int { + return p.num(a) - p.num(b) +} + +func (numberComparer) Separator(dst, a, b []byte) []byte { return nil } +func (numberComparer) Successor(dst, b []byte) []byte { return nil } + +func TestDB_CustomComparer(t *testing.T) { + h := newDbHarnessWopt(t, &opt.Options{ + Comparer: numberComparer{}, + WriteBuffer: 1000, + }) + defer h.close() + + h.put("[10]", "ten") + h.put("[0x14]", "twenty") + for i := 0; i < 2; i++ { + h.getVal("[10]", "ten") + h.getVal("[0xa]", "ten") + h.getVal("[20]", "twenty") + h.getVal("[0x14]", "twenty") + h.get("[15]", false) + h.get("[0xf]", false) + h.compactMem() + h.compactRange("[0]", "[9999]") + } + + for n := 0; n < 2; n++ { + for i := 0; i < 100; i++ { + v := fmt.Sprintf("[%d]", i*10) + h.put(v, v) + } + h.compactMem() + h.compactRange("[0]", "[1000000]") + } +} + +func TestDB_ManualCompaction(t *testing.T) { + h := newDbHarness(t) + defer h.close() + + if h.o.GetMaxMemCompationLevel() != 2 { + t.Fatal("fix test to reflect the config") + } + + h.putMulti(3, "p", "q") + h.tablesPerLevel("1,1,1") + + // Compaction range falls before files + h.compactRange("", "c") + h.tablesPerLevel("1,1,1") + + // Compaction range falls after files + h.compactRange("r", "z") + h.tablesPerLevel("1,1,1") + + // Compaction range overlaps files + h.compactRange("p1", "p9") + h.tablesPerLevel("0,0,1") + + // Populate a different range + h.putMulti(3, "c", "e") + h.tablesPerLevel("1,1,2") + + // Compact just the new range + h.compactRange("b", "f") + h.tablesPerLevel("0,0,2") + + // Compact all + h.putMulti(1, "a", "z") + h.tablesPerLevel("0,1,2") + h.compactRange("", "") + h.tablesPerLevel("0,0,1") +} + +func TestDB_BloomFilter(t *testing.T) { + h := newDbHarnessWopt(t, &opt.Options{ + DisableBlockCache: true, + Filter: filter.NewBloomFilter(10), + }) + defer h.close() + + key := func(i int) string { + return fmt.Sprintf("key%06d", i) + } + + const n = 10000 + + // Populate multiple layers + for i := 0; i < n; i++ { + h.put(key(i), key(i)) + } + h.compactMem() + h.compactRange("a", "z") + for i := 0; i < n; i += 100 { + h.put(key(i), key(i)) + } + h.compactMem() + + // Prevent auto compactions triggered by seeks + h.stor.DelaySync(storage.TypeTable) + + // Lookup present keys. Should rarely read from small sstable. + h.stor.SetReadCounter(storage.TypeTable) + for i := 0; i < n; i++ { + h.getVal(key(i), key(i)) + } + cnt := int(h.stor.ReadCounter()) + t.Logf("lookup of %d present keys yield %d sstable I/O reads", n, cnt) + + if min, max := n, n+2*n/100; cnt < min || cnt > max { + t.Errorf("num of sstable I/O reads of present keys not in range of %d - %d, got %d", min, max, cnt) + } + + // Lookup missing keys. Should rarely read from either sstable. 
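+	// With 10 bits per key the filter's expected false-positive rate is
+	// roughly 1%, so the 3%-of-n ceiling checked below leaves headroom.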
+ h.stor.ResetReadCounter() + for i := 0; i < n; i++ { + h.get(key(i)+".missing", false) + } + cnt = int(h.stor.ReadCounter()) + t.Logf("lookup of %d missing keys yield %d sstable I/O reads", n, cnt) + if max := 3 * n / 100; cnt > max { + t.Errorf("num of sstable I/O reads of missing keys was more than %d, got %d", max, cnt) + } + + h.stor.ReleaseSync(storage.TypeTable) +} + +func TestDB_Concurrent(t *testing.T) { + const n, secs, maxkey = 4, 2, 1000 + + runtime.GOMAXPROCS(n) + trun(t, func(h *dbHarness) { + var closeWg sync.WaitGroup + var stop uint32 + var cnt [n]uint32 + + for i := 0; i < n; i++ { + closeWg.Add(1) + go func(i int) { + var put, get, found uint + defer func() { + t.Logf("goroutine %d stopped after %d ops, put=%d get=%d found=%d missing=%d", + i, cnt[i], put, get, found, get-found) + closeWg.Done() + }() + + rnd := rand.New(rand.NewSource(int64(1000 + i))) + for atomic.LoadUint32(&stop) == 0 { + x := cnt[i] + + k := rnd.Intn(maxkey) + kstr := fmt.Sprintf("%016d", k) + + if (rnd.Int() % 2) > 0 { + put++ + h.put(kstr, fmt.Sprintf("%d.%d.%-1000d", k, i, x)) + } else { + get++ + v, err := h.db.Get([]byte(kstr), h.ro) + if err == nil { + found++ + rk, ri, rx := 0, -1, uint32(0) + fmt.Sscanf(string(v), "%d.%d.%d", &rk, &ri, &rx) + if rk != k { + t.Errorf("invalid key want=%d got=%d", k, rk) + } + if ri < 0 || ri >= n { + t.Error("invalid goroutine number: ", ri) + } else { + tx := atomic.LoadUint32(&(cnt[ri])) + if rx > tx { + t.Errorf("invalid seq number, %d > %d ", rx, tx) + } + } + } else if err != ErrNotFound { + t.Error("Get: got error: ", err) + return + } + } + atomic.AddUint32(&cnt[i], 1) + } + }(i) + } + + time.Sleep(secs * time.Second) + atomic.StoreUint32(&stop, 1) + closeWg.Wait() + }) + + runtime.GOMAXPROCS(1) +} + +func TestDB_Concurrent2(t *testing.T) { + const n, n2 = 4, 4000 + + runtime.GOMAXPROCS(n*2 + 2) + truno(t, &opt.Options{WriteBuffer: 30}, func(h *dbHarness) { + var closeWg sync.WaitGroup + var stop uint32 + + for i := 0; i < n; i++ { + closeWg.Add(1) + go func(i int) { + for k := 0; atomic.LoadUint32(&stop) == 0; k++ { + h.put(fmt.Sprintf("k%d", k), fmt.Sprintf("%d.%d.", k, i)+strings.Repeat("x", 10)) + } + closeWg.Done() + }(i) + } + + for i := 0; i < n; i++ { + closeWg.Add(1) + go func(i int) { + for k := 1000000; k < 0 || atomic.LoadUint32(&stop) == 0; k-- { + h.put(fmt.Sprintf("k%d", k), fmt.Sprintf("%d.%d.", k, i)+strings.Repeat("x", 10)) + } + closeWg.Done() + }(i) + } + + cmp := comparer.DefaultComparer + for i := 0; i < n2; i++ { + closeWg.Add(1) + go func(i int) { + it := h.db.NewIterator(nil, nil) + var pk []byte + for it.Next() { + kk := it.Key() + if cmp.Compare(kk, pk) <= 0 { + t.Errorf("iter %d: %q is successor of %q", i, pk, kk) + } + pk = append(pk[:0], kk...) 
+ var k, vk, vi int + if n, err := fmt.Sscanf(string(it.Key()), "k%d", &k); err != nil { + t.Errorf("iter %d: Scanf error on key %q: %v", i, it.Key(), err) + } else if n < 1 { + t.Errorf("iter %d: Cannot parse key %q", i, it.Key()) + } + if n, err := fmt.Sscanf(string(it.Value()), "%d.%d", &vk, &vi); err != nil { + t.Errorf("iter %d: Scanf error on value %q: %v", i, it.Value(), err) + } else if n < 2 { + t.Errorf("iter %d: Cannot parse value %q", i, it.Value()) + } + + if vk != k { + t.Errorf("iter %d: invalid value i=%d, want=%d got=%d", i, vi, k, vk) + } + } + if err := it.Error(); err != nil { + t.Errorf("iter %d: Got error: %v", i, err) + } + it.Release() + closeWg.Done() + }(i) + } + + atomic.StoreUint32(&stop, 1) + closeWg.Wait() + }) + + runtime.GOMAXPROCS(1) +} + +func TestDB_CreateReopenDbOnFile(t *testing.T) { + dbpath := filepath.Join(os.TempDir(), fmt.Sprintf("goleveldbtestCreateReopenDbOnFile-%d", os.Getuid())) + if err := os.RemoveAll(dbpath); err != nil { + t.Fatal("cannot remove old db: ", err) + } + defer os.RemoveAll(dbpath) + + for i := 0; i < 3; i++ { + stor, err := storage.OpenFile(dbpath) + if err != nil { + t.Fatalf("(%d) cannot open storage: %s", i, err) + } + db, err := Open(stor, nil) + if err != nil { + t.Fatalf("(%d) cannot open db: %s", i, err) + } + if err := db.Put([]byte("foo"), []byte("bar"), nil); err != nil { + t.Fatalf("(%d) cannot write to db: %s", i, err) + } + if err := db.Close(); err != nil { + t.Fatalf("(%d) cannot close db: %s", i, err) + } + if err := stor.Close(); err != nil { + t.Fatalf("(%d) cannot close storage: %s", i, err) + } + } +} + +func TestDB_CreateReopenDbOnFile2(t *testing.T) { + dbpath := filepath.Join(os.TempDir(), fmt.Sprintf("goleveldbtestCreateReopenDbOnFile2-%d", os.Getuid())) + if err := os.RemoveAll(dbpath); err != nil { + t.Fatal("cannot remove old db: ", err) + } + defer os.RemoveAll(dbpath) + + for i := 0; i < 3; i++ { + db, err := OpenFile(dbpath, nil) + if err != nil { + t.Fatalf("(%d) cannot open db: %s", i, err) + } + if err := db.Put([]byte("foo"), []byte("bar"), nil); err != nil { + t.Fatalf("(%d) cannot write to db: %s", i, err) + } + if err := db.Close(); err != nil { + t.Fatalf("(%d) cannot close db: %s", i, err) + } + } +} + +func TestDB_DeletionMarkersOnMemdb(t *testing.T) { + h := newDbHarness(t) + defer h.close() + + h.put("foo", "v1") + h.compactMem() + h.delete("foo") + h.get("foo", false) + h.getKeyVal("") +} + +func TestDB_LeveldbIssue178(t *testing.T) { + nKeys := (opt.DefaultCompactionTableSize / 30) * 5 + key1 := func(i int) string { + return fmt.Sprintf("my_key_%d", i) + } + key2 := func(i int) string { + return fmt.Sprintf("my_key_%d_xxx", i) + } + + // Disable compression since it affects the creation of layers and the + // code below is trying to test against a very specific scenario. + h := newDbHarnessWopt(t, &opt.Options{Compression: opt.NoCompression}) + defer h.close() + + // Create first key range. + batch := new(Batch) + for i := 0; i < nKeys; i++ { + batch.Put([]byte(key1(i)), []byte("value for range 1 key")) + } + h.write(batch) + + // Create second key range. + batch.Reset() + for i := 0; i < nKeys; i++ { + batch.Put([]byte(key2(i)), []byte("value for range 2 key")) + } + h.write(batch) + + // Delete second key range. + batch.Reset() + for i := 0; i < nKeys; i++ { + batch.Delete([]byte(key2(i))) + } + h.write(batch) + h.waitMemCompaction() + + // Run manual compaction. + h.compactRange(key1(0), key1(nKeys-1)) + + // Checking the keys. 
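+	// Only the first key range should survive: the second range was
+	// deleted and the manual compaction discards its tombstones.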
+ h.assertNumKeys(nKeys) +} + +func TestDB_LeveldbIssue200(t *testing.T) { + h := newDbHarness(t) + defer h.close() + + h.put("1", "b") + h.put("2", "c") + h.put("3", "d") + h.put("4", "e") + h.put("5", "f") + + iter := h.db.NewIterator(nil, h.ro) + + // Add an element that should not be reflected in the iterator. + h.put("25", "cd") + + iter.Seek([]byte("5")) + assertBytes(t, []byte("5"), iter.Key()) + iter.Prev() + assertBytes(t, []byte("4"), iter.Key()) + iter.Prev() + assertBytes(t, []byte("3"), iter.Key()) + iter.Next() + assertBytes(t, []byte("4"), iter.Key()) + iter.Next() + assertBytes(t, []byte("5"), iter.Key()) +} + +func TestDB_GoleveldbIssue74(t *testing.T) { + h := newDbHarnessWopt(t, &opt.Options{ + WriteBuffer: 1 * opt.MiB, + }) + defer h.close() + + const n, dur = 10000, 5 * time.Second + + runtime.GOMAXPROCS(runtime.NumCPU()) + + until := time.Now().Add(dur) + wg := new(sync.WaitGroup) + wg.Add(2) + var done uint32 + go func() { + var i int + defer func() { + t.Logf("WRITER DONE #%d", i) + atomic.StoreUint32(&done, 1) + wg.Done() + }() + + b := new(Batch) + for ; time.Now().Before(until) && atomic.LoadUint32(&done) == 0; i++ { + iv := fmt.Sprintf("VAL%010d", i) + for k := 0; k < n; k++ { + key := fmt.Sprintf("KEY%06d", k) + b.Put([]byte(key), []byte(key+iv)) + b.Put([]byte(fmt.Sprintf("PTR%06d", k)), []byte(key)) + } + h.write(b) + + b.Reset() + snap := h.getSnapshot() + iter := snap.NewIterator(util.BytesPrefix([]byte("PTR")), nil) + var k int + for ; iter.Next(); k++ { + ptrKey := iter.Key() + key := iter.Value() + + if _, err := snap.Get(ptrKey, nil); err != nil { + t.Fatalf("WRITER #%d snapshot.Get %q: %v", i, ptrKey, err) + } + if value, err := snap.Get(key, nil); err != nil { + t.Fatalf("WRITER #%d snapshot.Get %q: %v", i, key, err) + } else if string(value) != string(key)+iv { + t.Fatalf("WRITER #%d snapshot.Get %q got invalid value, want %q got %q", i, key, string(key)+iv, value) + } + + b.Delete(key) + b.Delete(ptrKey) + } + h.write(b) + iter.Release() + snap.Release() + if k != n { + t.Fatalf("#%d %d != %d", i, k, n) + } + } + }() + go func() { + var i int + defer func() { + t.Logf("READER DONE #%d", i) + atomic.StoreUint32(&done, 1) + wg.Done() + }() + for ; time.Now().Before(until) && atomic.LoadUint32(&done) == 0; i++ { + snap := h.getSnapshot() + iter := snap.NewIterator(util.BytesPrefix([]byte("PTR")), nil) + var prevValue string + var k int + for ; iter.Next(); k++ { + ptrKey := iter.Key() + key := iter.Value() + + if _, err := snap.Get(ptrKey, nil); err != nil { + t.Fatalf("READER #%d snapshot.Get %q: %v", i, ptrKey, err) + } + + if value, err := snap.Get(key, nil); err != nil { + t.Fatalf("READER #%d snapshot.Get %q: %v", i, key, err) + } else if prevValue != "" && string(value) != string(key)+prevValue { + t.Fatalf("READER #%d snapshot.Get %q got invalid value, want %q got %q", i, key, string(key)+prevValue, value) + } else { + prevValue = string(value[len(key):]) + } + } + iter.Release() + snap.Release() + if k > 0 && k != n { + t.Fatalf("#%d %d != %d", i, k, n) + } + } + }() + wg.Wait() +} + +func TestDB_GetProperties(t *testing.T) { + h := newDbHarness(t) + defer h.close() + + _, err := h.db.GetProperty("leveldb.num-files-at-level") + if err == nil { + t.Error("GetProperty() failed to detect missing level") + } + + _, err = h.db.GetProperty("leveldb.num-files-at-level0") + if err != nil { + t.Error("got unexpected error", err) + } + + _, err = h.db.GetProperty("leveldb.num-files-at-level0x") + if err == nil { + t.Error("GetProperty() failed to detect 
invalid level") + } +} + +func TestDB_GoleveldbIssue72and83(t *testing.T) { + h := newDbHarnessWopt(t, &opt.Options{ + WriteBuffer: 1 * opt.MiB, + OpenFilesCacheCapacity: 3, + }) + defer h.close() + + const n, wn, dur = 10000, 100, 30 * time.Second + + runtime.GOMAXPROCS(runtime.NumCPU()) + + randomData := func(prefix byte, i int) []byte { + data := make([]byte, 1+4+32+64+32) + _, err := crand.Reader.Read(data[1 : len(data)-8]) + if err != nil { + panic(err) + } + data[0] = prefix + binary.LittleEndian.PutUint32(data[len(data)-8:], uint32(i)) + binary.LittleEndian.PutUint32(data[len(data)-4:], util.NewCRC(data[:len(data)-4]).Value()) + return data + } + + keys := make([][]byte, n) + for i := range keys { + keys[i] = randomData(1, 0) + } + + until := time.Now().Add(dur) + wg := new(sync.WaitGroup) + wg.Add(3) + var done uint32 + go func() { + i := 0 + defer func() { + t.Logf("WRITER DONE #%d", i) + wg.Done() + }() + + b := new(Batch) + for ; i < wn && atomic.LoadUint32(&done) == 0; i++ { + b.Reset() + for _, k1 := range keys { + k2 := randomData(2, i) + b.Put(k2, randomData(42, i)) + b.Put(k1, k2) + } + if err := h.db.Write(b, h.wo); err != nil { + atomic.StoreUint32(&done, 1) + t.Fatalf("WRITER #%d db.Write: %v", i, err) + } + } + }() + go func() { + var i int + defer func() { + t.Logf("READER0 DONE #%d", i) + atomic.StoreUint32(&done, 1) + wg.Done() + }() + for ; time.Now().Before(until) && atomic.LoadUint32(&done) == 0; i++ { + snap := h.getSnapshot() + seq := snap.elem.seq + if seq == 0 { + snap.Release() + continue + } + iter := snap.NewIterator(util.BytesPrefix([]byte{1}), nil) + writei := int(seq/(n*2) - 1) + var k int + for ; iter.Next(); k++ { + k1 := iter.Key() + k2 := iter.Value() + k1checksum0 := binary.LittleEndian.Uint32(k1[len(k1)-4:]) + k1checksum1 := util.NewCRC(k1[:len(k1)-4]).Value() + if k1checksum0 != k1checksum1 { + t.Fatalf("READER0 #%d.%d W#%d invalid K1 checksum: %#x != %#x", i, k, k1checksum0, k1checksum0) + } + k2checksum0 := binary.LittleEndian.Uint32(k2[len(k2)-4:]) + k2checksum1 := util.NewCRC(k2[:len(k2)-4]).Value() + if k2checksum0 != k2checksum1 { + t.Fatalf("READER0 #%d.%d W#%d invalid K2 checksum: %#x != %#x", i, k, k2checksum0, k2checksum1) + } + kwritei := int(binary.LittleEndian.Uint32(k2[len(k2)-8:])) + if writei != kwritei { + t.Fatalf("READER0 #%d.%d W#%d invalid write iteration num: %d", i, k, writei, kwritei) + } + if _, err := snap.Get(k2, nil); err != nil { + t.Fatalf("READER0 #%d.%d W#%d snap.Get: %v\nk1: %x\n -> k2: %x", i, k, writei, err, k1, k2) + } + } + if err := iter.Error(); err != nil { + t.Fatalf("READER0 #%d.%d W#%d snap.Iterator: %v", i, k, writei, err) + } + iter.Release() + snap.Release() + if k > 0 && k != n { + t.Fatalf("READER0 #%d W#%d short read, got=%d want=%d", i, writei, k, n) + } + } + }() + go func() { + var i int + defer func() { + t.Logf("READER1 DONE #%d", i) + atomic.StoreUint32(&done, 1) + wg.Done() + }() + for ; time.Now().Before(until) && atomic.LoadUint32(&done) == 0; i++ { + iter := h.db.NewIterator(nil, nil) + seq := iter.(*dbIter).seq + if seq == 0 { + iter.Release() + continue + } + writei := int(seq/(n*2) - 1) + var k int + for ok := iter.Last(); ok; ok = iter.Prev() { + k++ + } + if err := iter.Error(); err != nil { + t.Fatalf("READER1 #%d.%d W#%d db.Iterator: %v", i, k, writei, err) + } + iter.Release() + if m := (writei+1)*n + n; k != m { + t.Fatalf("READER1 #%d W#%d short read, got=%d want=%d", i, writei, k, m) + } + } + }() + + wg.Wait() +} + +func TestDB_TransientError(t *testing.T) { + h := 
newDbHarnessWopt(t, &opt.Options{ + WriteBuffer: 128 * opt.KiB, + OpenFilesCacheCapacity: 3, + DisableCompactionBackoff: true, + }) + defer h.close() + + const ( + nSnap = 20 + nKey = 10000 + ) + + var ( + snaps [nSnap]*Snapshot + b = &Batch{} + ) + for i := range snaps { + vtail := fmt.Sprintf("VAL%030d", i) + b.Reset() + for k := 0; k < nKey; k++ { + key := fmt.Sprintf("KEY%8d", k) + b.Put([]byte(key), []byte(key+vtail)) + } + h.stor.SetEmuRandErr(storage.TypeTable, tsOpOpen, tsOpRead, tsOpReadAt) + if err := h.db.Write(b, nil); err != nil { + t.Logf("WRITE #%d error: %v", i, err) + h.stor.SetEmuRandErr(0, tsOpOpen, tsOpRead, tsOpReadAt, tsOpWrite) + for { + if err := h.db.Write(b, nil); err == nil { + break + } else if errors.IsCorrupted(err) { + t.Fatalf("WRITE #%d corrupted: %v", i, err) + } + } + } + + snaps[i] = h.db.newSnapshot() + b.Reset() + for k := 0; k < nKey; k++ { + key := fmt.Sprintf("KEY%8d", k) + b.Delete([]byte(key)) + } + h.stor.SetEmuRandErr(storage.TypeTable, tsOpOpen, tsOpRead, tsOpReadAt) + if err := h.db.Write(b, nil); err != nil { + t.Logf("WRITE #%d error: %v", i, err) + h.stor.SetEmuRandErr(0, tsOpOpen, tsOpRead, tsOpReadAt) + for { + if err := h.db.Write(b, nil); err == nil { + break + } else if errors.IsCorrupted(err) { + t.Fatalf("WRITE #%d corrupted: %v", i, err) + } + } + } + } + h.stor.SetEmuRandErr(0, tsOpOpen, tsOpRead, tsOpReadAt) + + runtime.GOMAXPROCS(runtime.NumCPU()) + + rnd := rand.New(rand.NewSource(0xecafdaed)) + wg := &sync.WaitGroup{} + for i, snap := range snaps { + wg.Add(2) + + go func(i int, snap *Snapshot, sk []int) { + defer wg.Done() + + vtail := fmt.Sprintf("VAL%030d", i) + for _, k := range sk { + key := fmt.Sprintf("KEY%8d", k) + xvalue, err := snap.Get([]byte(key), nil) + if err != nil { + t.Fatalf("READER_GET #%d SEQ=%d K%d error: %v", i, snap.elem.seq, k, err) + } + value := key + vtail + if !bytes.Equal([]byte(value), xvalue) { + t.Fatalf("READER_GET #%d SEQ=%d K%d invalid value: want %q, got %q", i, snap.elem.seq, k, value, xvalue) + } + } + }(i, snap, rnd.Perm(nKey)) + + go func(i int, snap *Snapshot) { + defer wg.Done() + + vtail := fmt.Sprintf("VAL%030d", i) + iter := snap.NewIterator(nil, nil) + defer iter.Release() + for k := 0; k < nKey; k++ { + if !iter.Next() { + if err := iter.Error(); err != nil { + t.Fatalf("READER_ITER #%d K%d error: %v", i, k, err) + } else { + t.Fatalf("READER_ITER #%d K%d eoi", i, k) + } + } + key := fmt.Sprintf("KEY%8d", k) + xkey := iter.Key() + if !bytes.Equal([]byte(key), xkey) { + t.Fatalf("READER_ITER #%d K%d invalid key: want %q, got %q", i, k, key, xkey) + } + value := key + vtail + xvalue := iter.Value() + if !bytes.Equal([]byte(value), xvalue) { + t.Fatalf("READER_ITER #%d K%d invalid value: want %q, got %q", i, k, value, xvalue) + } + } + }(i, snap) + } + + wg.Wait() +} + +func TestDB_UkeyShouldntHopAcrossTable(t *testing.T) { + h := newDbHarnessWopt(t, &opt.Options{ + WriteBuffer: 112 * opt.KiB, + CompactionTableSize: 90 * opt.KiB, + CompactionExpandLimitFactor: 1, + }) + defer h.close() + + const ( + nSnap = 190 + nKey = 140 + ) + + var ( + snaps [nSnap]*Snapshot + b = &Batch{} + ) + for i := range snaps { + vtail := fmt.Sprintf("VAL%030d", i) + b.Reset() + for k := 0; k < nKey; k++ { + key := fmt.Sprintf("KEY%08d", k) + b.Put([]byte(key), []byte(key+vtail)) + } + if err := h.db.Write(b, nil); err != nil { + t.Fatalf("WRITE #%d error: %v", i, err) + } + + snaps[i] = h.db.newSnapshot() + b.Reset() + for k := 0; k < nKey; k++ { + key := fmt.Sprintf("KEY%08d", k) + b.Delete([]byte(key)) + 
} + if err := h.db.Write(b, nil); err != nil { + t.Fatalf("WRITE #%d error: %v", i, err) + } + } + + h.compactMem() + + h.waitCompaction() + for level, tables := range h.db.s.stVersion.tables { + for _, table := range tables { + t.Logf("L%d@%d %q:%q", level, table.file.Num(), table.imin, table.imax) + } + } + + h.compactRangeAt(0, "", "") + h.waitCompaction() + for level, tables := range h.db.s.stVersion.tables { + for _, table := range tables { + t.Logf("L%d@%d %q:%q", level, table.file.Num(), table.imin, table.imax) + } + } + h.compactRangeAt(1, "", "") + h.waitCompaction() + for level, tables := range h.db.s.stVersion.tables { + for _, table := range tables { + t.Logf("L%d@%d %q:%q", level, table.file.Num(), table.imin, table.imax) + } + } + runtime.GOMAXPROCS(runtime.NumCPU()) + + wg := &sync.WaitGroup{} + for i, snap := range snaps { + wg.Add(1) + + go func(i int, snap *Snapshot) { + defer wg.Done() + + vtail := fmt.Sprintf("VAL%030d", i) + for k := 0; k < nKey; k++ { + key := fmt.Sprintf("KEY%08d", k) + xvalue, err := snap.Get([]byte(key), nil) + if err != nil { + t.Fatalf("READER_GET #%d SEQ=%d K%d error: %v", i, snap.elem.seq, k, err) + } + value := key + vtail + if !bytes.Equal([]byte(value), xvalue) { + t.Fatalf("READER_GET #%d SEQ=%d K%d invalid value: want %q, got %q", i, snap.elem.seq, k, value, xvalue) + } + } + }(i, snap) + } + + wg.Wait() +} + +func TestDB_TableCompactionBuilder(t *testing.T) { + stor := newTestStorage(t) + defer stor.Close() + + const nSeq = 99 + + o := &opt.Options{ + WriteBuffer: 112 * opt.KiB, + CompactionTableSize: 43 * opt.KiB, + CompactionExpandLimitFactor: 1, + CompactionGPOverlapsFactor: 1, + DisableBlockCache: true, + } + s, err := newSession(stor, o) + if err != nil { + t.Fatal(err) + } + if err := s.create(); err != nil { + t.Fatal(err) + } + defer s.close() + var ( + seq uint64 + targetSize = 5 * o.CompactionTableSize + value = bytes.Repeat([]byte{'0'}, 100) + ) + for i := 0; i < 2; i++ { + tw, err := s.tops.create() + if err != nil { + t.Fatal(err) + } + for k := 0; tw.tw.BytesLen() < targetSize; k++ { + key := []byte(fmt.Sprintf("%09d", k)) + seq += nSeq - 1 + for x := uint64(0); x < nSeq; x++ { + if err := tw.append(newIkey(key, seq-x, ktVal), value); err != nil { + t.Fatal(err) + } + } + } + tf, err := tw.finish() + if err != nil { + t.Fatal(err) + } + rec := &sessionRecord{numLevel: s.o.GetNumLevel()} + rec.addTableFile(i, tf) + if err := s.commit(rec); err != nil { + t.Fatal(err) + } + } + + // Build grandparent. + v := s.version() + c := newCompaction(s, v, 1, append(tFiles{}, v.tables[1]...)) + rec := &sessionRecord{numLevel: s.o.GetNumLevel()} + b := &tableCompactionBuilder{ + s: s, + c: c, + rec: rec, + stat1: new(cStatsStaging), + minSeq: 0, + strict: true, + tableSize: o.CompactionTableSize/3 + 961, + } + if err := b.run(new(compactionTransactCounter)); err != nil { + t.Fatal(err) + } + for _, t := range c.tables[0] { + rec.delTable(c.level, t.file.Num()) + } + if err := s.commit(rec); err != nil { + t.Fatal(err) + } + c.release() + + // Build level-1. 
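+	// Compact level-0 into level-1 tables capped at CompactionTableSize;
+	// the builder must not let one user key span two adjacent output
+	// tables, which the imax/imin comparison further down verifies.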
+	v = s.version()
+	c = newCompaction(s, v, 0, append(tFiles{}, v.tables[0]...))
+	rec = &sessionRecord{numLevel: s.o.GetNumLevel()}
+	b = &tableCompactionBuilder{
+		s: s,
+		c: c,
+		rec: rec,
+		stat1: new(cStatsStaging),
+		minSeq: 0,
+		strict: true,
+		tableSize: o.CompactionTableSize,
+	}
+	if err := b.run(new(compactionTransactCounter)); err != nil {
+		t.Fatal(err)
+	}
+	for _, t := range c.tables[0] {
+		rec.delTable(c.level, t.file.Num())
+	}
+	// Move grandparent to level-3
+	for _, t := range v.tables[2] {
+		rec.delTable(2, t.file.Num())
+		rec.addTableFile(3, t)
+	}
+	if err := s.commit(rec); err != nil {
+		t.Fatal(err)
+	}
+	c.release()
+
+	v = s.version()
+	for level, want := range []bool{false, true, false, true, false} {
+		got := len(v.tables[level]) > 0
+		if want != got {
+			t.Fatalf("invalid level-%d tables len: want %v, got %v", level, want, got)
+		}
+	}
+	for i, f := range v.tables[1][:len(v.tables[1])-1] {
+		nf := v.tables[1][i+1]
+		if bytes.Equal(f.imax.ukey(), nf.imin.ukey()) {
+			t.Fatalf("KEY %q hop across table %d .. %d", f.imax.ukey(), f.file.Num(), nf.file.Num())
+		}
+	}
+	v.release()
+
+	// Compaction with transient error.
+	v = s.version()
+	c = newCompaction(s, v, 1, append(tFiles{}, v.tables[1]...))
+	rec = &sessionRecord{numLevel: s.o.GetNumLevel()}
+	b = &tableCompactionBuilder{
+		s: s,
+		c: c,
+		rec: rec,
+		stat1: new(cStatsStaging),
+		minSeq: 0,
+		strict: true,
+		tableSize: o.CompactionTableSize,
+	}
+	stor.SetEmuErrOnce(storage.TypeTable, tsOpSync)
+	stor.SetEmuRandErr(storage.TypeTable, tsOpRead, tsOpReadAt, tsOpWrite)
+	stor.SetEmuRandErrProb(0xf0)
+	for {
+		if err := b.run(new(compactionTransactCounter)); err != nil {
+			t.Logf("(expected) b.run: %v", err)
+		} else {
+			break
+		}
+	}
+	if err := s.commit(rec); err != nil {
+		t.Fatal(err)
+	}
+	c.release()
+
+	stor.SetEmuErrOnce(0, tsOpSync)
+	stor.SetEmuRandErr(0, tsOpRead, tsOpReadAt, tsOpWrite)
+
+	v = s.version()
+	if len(v.tables[1]) != len(v.tables[2]) {
+		t.Fatalf("invalid tables length, want %d, got %d", len(v.tables[1]), len(v.tables[2]))
+	}
+	for i, f0 := range v.tables[1] {
+		f1 := v.tables[2][i]
+		iter0 := s.tops.newIterator(f0, nil, nil)
+		iter1 := s.tops.newIterator(f1, nil, nil)
+		for j := 0; true; j++ {
+			next0 := iter0.Next()
+			next1 := iter1.Next()
+			if next0 != next1 {
+				t.Fatalf("#%d.%d invalid eoi: want %v, got %v", i, j, next0, next1)
+			}
+			key0 := iter0.Key()
+			key1 := iter1.Key()
+			if !bytes.Equal(key0, key1) {
+				t.Fatalf("#%d.%d invalid key: want %q, got %q", i, j, key0, key1)
+			}
+			if !next0 {
+				break
+			}
+		}
+		iter0.Release()
+		iter1.Release()
+	}
+	v.release()
+}
diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db_util.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db_util.go
new file mode 100644
index 0000000000..a8a2bdf72e
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db_util.go
@@ -0,0 +1,100 @@
+// Copyright (c) 2012, Suryandaru Triandana
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+package leveldb
+
+import (
+	"github.com/syndtr/goleveldb/leveldb/errors"
+	"github.com/syndtr/goleveldb/leveldb/iterator"
+	"github.com/syndtr/goleveldb/leveldb/opt"
+	"github.com/syndtr/goleveldb/leveldb/storage"
+	"github.com/syndtr/goleveldb/leveldb/util"
+)
+
+// Reader is the interface that wraps basic Get and NewIterator methods.
+// This interface is implemented by both DB and Snapshot.
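+//
+// Read-only helpers can accept either one; a minimal sketch:
+//
+//	func countKeys(r Reader) (n int) {
+//		iter := r.NewIterator(nil, nil)
+//		for iter.Next() {
+//			n++
+//		}
+//		iter.Release()
+//		return n
+//	}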
+type Reader interface { + Get(key []byte, ro *opt.ReadOptions) (value []byte, err error) + NewIterator(slice *util.Range, ro *opt.ReadOptions) iterator.Iterator +} + +type Sizes []uint64 + +// Sum returns sum of the sizes. +func (p Sizes) Sum() (n uint64) { + for _, s := range p { + n += s + } + return n +} + +// Logging. +func (db *DB) log(v ...interface{}) { db.s.log(v...) } +func (db *DB) logf(format string, v ...interface{}) { db.s.logf(format, v...) } + +// Check and clean files. +func (db *DB) checkAndCleanFiles() error { + v := db.s.version() + defer v.release() + + tablesMap := make(map[uint64]bool) + for _, tables := range v.tables { + for _, t := range tables { + tablesMap[t.file.Num()] = false + } + } + + files, err := db.s.getFiles(storage.TypeAll) + if err != nil { + return err + } + + var nTables int + var rem []storage.File + for _, f := range files { + keep := true + switch f.Type() { + case storage.TypeManifest: + keep = f.Num() >= db.s.manifestFile.Num() + case storage.TypeJournal: + if db.frozenJournalFile != nil { + keep = f.Num() >= db.frozenJournalFile.Num() + } else { + keep = f.Num() >= db.journalFile.Num() + } + case storage.TypeTable: + _, keep = tablesMap[f.Num()] + if keep { + tablesMap[f.Num()] = true + nTables++ + } + } + + if !keep { + rem = append(rem, f) + } + } + + if nTables != len(tablesMap) { + var missing []*storage.FileInfo + for num, present := range tablesMap { + if !present { + missing = append(missing, &storage.FileInfo{Type: storage.TypeTable, Num: num}) + db.logf("db@janitor table missing @%d", num) + } + } + return errors.NewErrCorrupted(nil, &errors.ErrMissingFiles{Files: missing}) + } + + db.logf("db@janitor F·%d G·%d", len(files), len(rem)) + for _, f := range rem { + db.logf("db@janitor removing %s-%d", f.Type(), f.Num()) + if err := f.Remove(); err != nil { + return err + } + } + return nil +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db_write.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db_write.go new file mode 100644 index 0000000000..e1cf30c536 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/db_write.go @@ -0,0 +1,311 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package leveldb + +import ( + "time" + + "github.com/syndtr/goleveldb/leveldb/memdb" + "github.com/syndtr/goleveldb/leveldb/opt" + "github.com/syndtr/goleveldb/leveldb/util" +) + +func (db *DB) writeJournal(b *Batch) error { + w, err := db.journal.Next() + if err != nil { + return err + } + if _, err := w.Write(b.encode()); err != nil { + return err + } + if err := db.journal.Flush(); err != nil { + return err + } + if b.sync { + return db.journalWriter.Sync() + } + return nil +} + +func (db *DB) jWriter() { + defer db.closeW.Done() + for { + select { + case b := <-db.journalC: + if b != nil { + db.journalAckC <- db.writeJournal(b) + } + case _, _ = <-db.closeC: + return + } + } +} + +func (db *DB) rotateMem(n int) (mem *memDB, err error) { + // Wait for pending memdb compaction. + err = db.compSendIdle(db.mcompCmdC) + if err != nil { + return + } + + // Create new memdb and journal. + mem, err = db.newMem(n) + if err != nil { + return + } + + // Schedule memdb compaction. 
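+	// Unlike compSendIdle above, which blocks until the compactor has
+	// drained pending work, compSendTrigger only signals the compactor
+	// and does not wait for the compaction to finish.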
+	db.compSendTrigger(db.mcompCmdC)
+	return
+}
+
+func (db *DB) flush(n int) (mem *memDB, nn int, err error) {
+	delayed := false
+	flush := func() (retry bool) {
+		v := db.s.version()
+		defer v.release()
+		mem = db.getEffectiveMem()
+		defer func() {
+			if retry {
+				mem.decref()
+				mem = nil
+			}
+		}()
+		nn = mem.mdb.Free()
+		switch {
+		case v.tLen(0) >= db.s.o.GetWriteL0SlowdownTrigger() && !delayed:
+			delayed = true
+			time.Sleep(time.Millisecond)
+		case nn >= n:
+			return false
+		case v.tLen(0) >= db.s.o.GetWriteL0PauseTrigger():
+			delayed = true
+			err = db.compSendIdle(db.tcompCmdC)
+			if err != nil {
+				return false
+			}
+		default:
+			// Allow memdb to grow if it has no entry.
+			if mem.mdb.Len() == 0 {
+				nn = n
+			} else {
+				mem.decref()
+				mem, err = db.rotateMem(n)
+				if err == nil {
+					nn = mem.mdb.Free()
+				} else {
+					nn = 0
+				}
+			}
+			return false
+		}
+		return true
+	}
+	start := time.Now()
+	for flush() {
+	}
+	if delayed {
+		db.writeDelay += time.Since(start)
+		db.writeDelayN++
+	} else if db.writeDelayN > 0 {
+		db.logf("db@write was delayed N·%d T·%v", db.writeDelayN, db.writeDelay)
+		db.writeDelay = 0
+		db.writeDelayN = 0
+	}
+	return
+}
+
+// Write applies the given batch to the DB. The batch will be applied
+// sequentially.
+//
+// It is safe to modify the contents of the arguments after Write returns.
+func (db *DB) Write(b *Batch, wo *opt.WriteOptions) (err error) {
+	err = db.ok()
+	if err != nil || b == nil || b.Len() == 0 {
+		return
+	}
+
+	b.init(wo.GetSync())
+
+	// The write happens synchronously.
+	select {
+	case db.writeC <- b:
+		if <-db.writeMergedC {
+			return <-db.writeAckC
+		}
+	case db.writeLockC <- struct{}{}:
+	case err = <-db.compPerErrC:
+		return
+	case _, _ = <-db.closeC:
+		return ErrClosed
+	}
+
+	merged := 0
+	danglingMerge := false
+	defer func() {
+		if danglingMerge {
+			db.writeMergedC <- false
+		} else {
+			<-db.writeLockC
+		}
+		for i := 0; i < merged; i++ {
+			db.writeAckC <- err
+		}
+	}()
+
+	mem, memFree, err := db.flush(b.size())
+	if err != nil {
+		return
+	}
+	defer mem.decref()
+
+	// Calculate maximum size of the batch.
+	m := 1 << 20
+	if x := b.size(); x <= 128<<10 {
+		m = x + (128 << 10)
+	}
+	m = minInt(m, memFree)
+
+	// Merge with other batches.
+drain:
+	for b.size() < m && !b.sync {
+		select {
+		case nb := <-db.writeC:
+			if b.size()+nb.size() <= m {
+				b.append(nb)
+				db.writeMergedC <- true
+				merged++
+			} else {
+				danglingMerge = true
+				break drain
+			}
+		default:
+			break drain
+		}
+	}
+
+	// Set batch first seq number relative from last seq.
+	b.seq = db.seq + 1
+
+	// Write journal concurrently if it is large enough.
+	if b.size() >= (128 << 10) {
+		// Push the write batch to the journal writer
+		select {
+		case db.journalC <- b:
+			// Write into memdb
+			if berr := b.memReplay(mem.mdb); berr != nil {
+				panic(berr)
+			}
+		case err = <-db.compPerErrC:
+			return
+		case _, _ = <-db.closeC:
+			err = ErrClosed
+			return
+		}
+		// Wait for journal writer
+		select {
+		case err = <-db.journalAckC:
+			if err != nil {
+				// Revert memdb if error detected
+				if berr := b.revertMemReplay(mem.mdb); berr != nil {
+					panic(berr)
+				}
+				return
+			}
+		case _, _ = <-db.closeC:
+			err = ErrClosed
+			return
+		}
+	} else {
+		err = db.writeJournal(b)
+		if err != nil {
+			return
+		}
+		if berr := b.memReplay(mem.mdb); berr != nil {
+			panic(berr)
+		}
+	}
+
+	// Set last seq number.
+	db.addSeq(uint64(b.Len()))
+
+	if b.size() >= memFree {
+		db.rotateMem(0)
+	}
+	return
+}
+
+// Put sets the value for the given key. It overwrites any previous value
+// for that key; a DB is not a multi-map.
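+//
+// Put is shorthand for writing a single-entry batch:
+//
+//	b := new(Batch)
+//	b.Put(key, value)
+//	err := db.Write(b, wo)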
+//
+// It is safe to modify the contents of the arguments after Put returns.
+func (db *DB) Put(key, value []byte, wo *opt.WriteOptions) error {
+	b := new(Batch)
+	b.Put(key, value)
+	return db.Write(b, wo)
+}
+
+// Delete deletes the value for the given key. It returns ErrNotFound if
+// the DB does not contain the key.
+//
+// It is safe to modify the contents of the arguments after Delete returns.
+func (db *DB) Delete(key []byte, wo *opt.WriteOptions) error {
+	b := new(Batch)
+	b.Delete(key)
+	return db.Write(b, wo)
+}
+
+func isMemOverlaps(icmp *iComparer, mem *memdb.DB, min, max []byte) bool {
+	iter := mem.NewIterator(nil)
+	defer iter.Release()
+	return (max == nil || (iter.First() && icmp.uCompare(max, iKey(iter.Key()).ukey()) >= 0)) &&
+		(min == nil || (iter.Last() && icmp.uCompare(min, iKey(iter.Key()).ukey()) <= 0))
+}
+
+// CompactRange compacts the underlying DB for the given key range.
+// In particular, deleted and overwritten versions are discarded,
+// and the data is rearranged to reduce the cost of operations
+// needed to access the data. This operation should typically only
+// be invoked by users who understand the underlying implementation.
+//
+// A nil Range.Start is treated as a key before all keys in the DB.
+// And a nil Range.Limit is treated as a key after all keys in the DB.
+// Therefore if both are nil then it will compact the entire DB.
+func (db *DB) CompactRange(r util.Range) error {
+	if err := db.ok(); err != nil {
+		return err
+	}
+
+	// Lock writer.
+	select {
+	case db.writeLockC <- struct{}{}:
+	case err := <-db.compPerErrC:
+		return err
+	case _, _ = <-db.closeC:
+		return ErrClosed
+	}
+
+	// Check for overlaps in memdb.
+	mem := db.getEffectiveMem()
+	defer mem.decref()
+	if isMemOverlaps(db.s.icmp, mem.mdb, r.Start, r.Limit) {
+		// Memdb compaction.
+		if _, err := db.rotateMem(0); err != nil {
+			<-db.writeLockC
+			return err
+		}
+		<-db.writeLockC
+		if err := db.compSendIdle(db.mcompCmdC); err != nil {
+			return err
+		}
+	} else {
+		<-db.writeLockC
+	}
+
+	// Table compaction.
+	return db.compSendRange(db.tcompCmdC, -1, r.Start, r.Limit)
+}
diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/doc.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/doc.go
new file mode 100644
index 0000000000..53f13bb24c
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/doc.go
@@ -0,0 +1,90 @@
+// Copyright (c) 2012, Suryandaru Triandana
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+// Package leveldb provides an implementation of the LevelDB key/value
+// database.
+//
+// Create or open a database:
+//
+//	db, err := leveldb.OpenFile("path/to/db", nil)
+//	...
+//	defer db.Close()
+//	...
+//
+// Read or modify the database content:
+//
+//	// Remember that the contents of the returned slice should not be modified.
+//	data, err := db.Get([]byte("key"), nil)
+//	...
+//	err = db.Put([]byte("key"), []byte("value"), nil)
+//	...
+//	err = db.Delete([]byte("key"), nil)
+//	...
+//
+// Iterate over database content:
+//
+//	iter := db.NewIterator(nil, nil)
+//	for iter.Next() {
+//		// Remember that the contents of the returned slice should not be modified,
+//		// and are only valid until the next call to Next.
+//		key := iter.Key()
+//		value := iter.Value()
+//		...
+//	}
+//	iter.Release()
+//	err = iter.Error()
+//	...
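+//
+// Distinguish a missing key from a real error:
+//
+//	data, err := db.Get([]byte("key"), nil)
+//	if err == leveldb.ErrNotFound {
+//		// the key does not exist
+//	}
+//	...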
+//
+// Iterate over a subset of database content with a particular prefix:
+//
+//	iter := db.NewIterator(util.BytesPrefix([]byte("foo-")), nil)
+//	for iter.Next() {
+//		// Use key/value.
+//		...
+//	}
+//	iter.Release()
+//	err = iter.Error()
+//	...
+//
+// Seek-then-Iterate:
+//
+//	iter := db.NewIterator(nil, nil)
+//	for ok := iter.Seek(key); ok; ok = iter.Next() {
+//		// Use key/value.
+//		...
+//	}
+//	iter.Release()
+//	err = iter.Error()
+//	...
+//
+// Iterate over a subset of database content:
+//
+//	iter := db.NewIterator(&util.Range{Start: []byte("foo"), Limit: []byte("xoo")}, nil)
+//	for iter.Next() {
+//		// Use key/value.
+//		...
+//	}
+//	iter.Release()
+//	err = iter.Error()
+//	...
+//
+// Batch writes:
+//
+//	batch := new(leveldb.Batch)
+//	batch.Put([]byte("foo"), []byte("value"))
+//	batch.Put([]byte("bar"), []byte("another value"))
+//	batch.Delete([]byte("baz"))
+//	err = db.Write(batch, nil)
+//	...
+//
+// Use a bloom filter:
+//
+//	o := &opt.Options{
+//		Filter: filter.NewBloomFilter(10),
+//	}
+//	db, err := leveldb.OpenFile("path/to/db", o)
+//	...
+//	defer db.Close()
+//	...
+package leveldb
diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/errors.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/errors.go
new file mode 100644
index 0000000000..29d0d2f279
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/errors.go
@@ -0,0 +1,18 @@
+// Copyright (c) 2014, Suryandaru Triandana
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+package leveldb
+
+import (
+	"github.com/syndtr/goleveldb/leveldb/errors"
+)
+
+var (
+	ErrNotFound         = errors.ErrNotFound
+	ErrSnapshotReleased = errors.New("leveldb: snapshot released")
+	ErrIterReleased     = errors.New("leveldb: iterator released")
+	ErrClosed           = errors.New("leveldb: closed")
+)
diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/errors/errors.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/errors/errors.go
new file mode 100644
index 0000000000..84b5d6b7b2
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/errors/errors.go
@@ -0,0 +1,76 @@
+// Copyright (c) 2014, Suryandaru Triandana
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+// Package errors provides common error types used throughout leveldb.
+package errors
+
+import (
+	"errors"
+	"fmt"
+
+	"github.com/syndtr/goleveldb/leveldb/storage"
+	"github.com/syndtr/goleveldb/leveldb/util"
+)
+
+var (
+	ErrNotFound    = New("leveldb: not found")
+	ErrReleased    = util.ErrReleased
+	ErrHasReleaser = util.ErrHasReleaser
+)
+
+// New returns an error that formats as the given text.
+func New(text string) error {
+	return errors.New(text)
+}
+
+// ErrCorrupted is the type that wraps errors that indicate corruption in
+// the database.
+type ErrCorrupted struct {
+	File *storage.FileInfo
+	Err  error
+}
+
+func (e *ErrCorrupted) Error() string {
+	if e.File != nil {
+		return fmt.Sprintf("%v [file=%v]", e.Err, e.File)
+	} else {
+		return e.Err.Error()
+	}
+}
+
+// NewErrCorrupted creates a new ErrCorrupted error.
+func NewErrCorrupted(f storage.File, err error) error {
+	return &ErrCorrupted{storage.NewFileInfo(f), err}
+}
+
+// IsCorrupted returns a boolean indicating whether the error indicates
+// a corruption.
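+//
+// It is typically used to decide whether a failed operation is worth
+// retrying; for example:
+//
+//	if err != nil && errors.IsCorrupted(err) {
+//		// the store itself is damaged; do not retry
+//	}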
+func IsCorrupted(err error) bool {
+	switch err.(type) {
+	case *ErrCorrupted:
+		return true
+	}
+	return false
+}
+
+// ErrMissingFiles is the type that indicates a corruption due to missing
+// files.
+type ErrMissingFiles struct {
+	Files []*storage.FileInfo
+}
+
+func (e *ErrMissingFiles) Error() string { return "file missing" }
+
+// SetFile sets 'file info' of the given error with the given file.
+// Currently only ErrCorrupted is supported; otherwise it does nothing.
+func SetFile(err error, f storage.File) error {
+	switch x := err.(type) {
+	case *ErrCorrupted:
+		x.File = storage.NewFileInfo(f)
+		return x
+	}
+	return err
+}
diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/external_test.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/external_test.go
new file mode 100644
index 0000000000..b328ece4e2
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/external_test.go
@@ -0,0 +1,58 @@
+// Copyright (c) 2014, Suryandaru Triandana
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+package leveldb
+
+import (
+	. "github.com/onsi/ginkgo"
+	. "github.com/onsi/gomega"
+
+	"github.com/syndtr/goleveldb/leveldb/opt"
+	"github.com/syndtr/goleveldb/leveldb/testutil"
+)
+
+var _ = testutil.Defer(func() {
+	Describe("Leveldb external", func() {
+		o := &opt.Options{
+			DisableBlockCache:      true,
+			BlockRestartInterval:   5,
+			BlockSize:              80,
+			Compression:            opt.NoCompression,
+			OpenFilesCacheCapacity: -1,
+			Strict:                 opt.StrictAll,
+			WriteBuffer:            1000,
+			CompactionTableSize:    2000,
+		}
+
+		Describe("write test", func() {
+			It("should do write correctly", func(done Done) {
+				db := newTestingDB(o, nil, nil)
+				t := testutil.DBTesting{
+					DB:      db,
+					Deleted: testutil.KeyValue_Generate(nil, 500, 1, 50, 5, 5).Clone(),
+				}
+				testutil.DoDBTesting(&t)
+				db.TestClose()
+				done <- true
+			}, 20.0)
+		})
+
+		Describe("read test", func() {
+			testutil.AllKeyValueTesting(nil, nil, func(kv testutil.KeyValue) testutil.DB {
+				// Building the DB.
+				db := newTestingDB(o, nil, nil)
+				kv.IterateShuffled(nil, func(i int, key, value []byte) {
+					err := db.TestPut(key, value)
+					Expect(err).NotTo(HaveOccurred())
+				})
+
+				return db
+			}, func(db testutil.DB) {
+				db.(*testingDB).TestClose()
+			})
+		})
+	})
+})
diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/filter.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/filter.go
new file mode 100644
index 0000000000..37c1e146bc
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/filter.go
@@ -0,0 +1,31 @@
+// Copyright (c) 2012, Suryandaru Triandana
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+ +package leveldb + +import ( + "github.com/syndtr/goleveldb/leveldb/filter" +) + +type iFilter struct { + filter.Filter +} + +func (f iFilter) Contains(filter, key []byte) bool { + return f.Filter.Contains(filter, iKey(key).ukey()) +} + +func (f iFilter) NewGenerator() filter.FilterGenerator { + return iFilterGenerator{f.Filter.NewGenerator()} +} + +type iFilterGenerator struct { + filter.FilterGenerator +} + +func (g iFilterGenerator) Add(key []byte) { + g.FilterGenerator.Add(iKey(key).ukey()) +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/filter/bloom.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/filter/bloom.go new file mode 100644 index 0000000000..bab0e99705 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/filter/bloom.go @@ -0,0 +1,116 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package filter + +import ( + "github.com/syndtr/goleveldb/leveldb/util" +) + +func bloomHash(key []byte) uint32 { + return util.Hash(key, 0xbc9f1d34) +} + +type bloomFilter int + +// The bloom filter serializes its parameters and is backward compatible +// with respect to them. Therefor, its parameters are not added to its +// name. +func (bloomFilter) Name() string { + return "leveldb.BuiltinBloomFilter" +} + +func (f bloomFilter) Contains(filter, key []byte) bool { + nBytes := len(filter) - 1 + if nBytes < 1 { + return false + } + nBits := uint32(nBytes * 8) + + // Use the encoded k so that we can read filters generated by + // bloom filters created using different parameters. + k := filter[nBytes] + if k > 30 { + // Reserved for potentially new encodings for short bloom filters. + // Consider it a match. + return true + } + + kh := bloomHash(key) + delta := (kh >> 17) | (kh << 15) // Rotate right 17 bits + for j := uint8(0); j < k; j++ { + bitpos := kh % nBits + if (uint32(filter[bitpos/8]) & (1 << (bitpos % 8))) == 0 { + return false + } + kh += delta + } + return true +} + +func (f bloomFilter) NewGenerator() FilterGenerator { + // Round down to reduce probing cost a little bit. + k := uint8(f * 69 / 100) // 0.69 =~ ln(2) + if k < 1 { + k = 1 + } else if k > 30 { + k = 30 + } + return &bloomFilterGenerator{ + n: int(f), + k: k, + } +} + +type bloomFilterGenerator struct { + n int + k uint8 + + keyHashes []uint32 +} + +func (g *bloomFilterGenerator) Add(key []byte) { + // Use double-hashing to generate a sequence of hash values. + // See analysis in [Kirsch,Mitzenmacher 2006]. + g.keyHashes = append(g.keyHashes, bloomHash(key)) +} + +func (g *bloomFilterGenerator) Generate(b Buffer) { + // Compute bloom filter size (in both bits and bytes) + nBits := uint32(len(g.keyHashes) * g.n) + // For small n, we can see a very high false positive rate. Fix it + // by enforcing a minimum bloom filter length. + if nBits < 64 { + nBits = 64 + } + nBytes := (nBits + 7) / 8 + nBits = nBytes * 8 + + dest := b.Alloc(int(nBytes) + 1) + dest[nBytes] = g.k + for _, kh := range g.keyHashes { + delta := (kh >> 17) | (kh << 15) // Rotate right 17 bits + for j := uint8(0); j < g.k; j++ { + bitpos := kh % nBits + dest[bitpos/8] |= (1 << (bitpos % 8)) + kh += delta + } + } + + g.keyHashes = g.keyHashes[:0] +} + +// NewBloomFilter creates a new initialized bloom filter for given +// bitsPerKey. 
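+//
+// As a worked example: with bitsPerKey = 10 the generator uses
+// k = 10*69/100 = 6 hash probes per key (clamped to [1, 30]), and a filter
+// over n keys occupies roughly n*10 bits (minimum 64, rounded up to whole
+// bytes), plus one trailing byte that encodes k.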
+// +// Since bitsPerKey is persisted individually for each bloom filter +// serialization, bloom filters are backwards compatible with respect to +// changing bitsPerKey. This means that no big performance penalty will +// be experienced when changing the parameter. See documentation for +// opt.Options.Filter for more information. +func NewBloomFilter(bitsPerKey int) Filter { + return bloomFilter(bitsPerKey) +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/filter/bloom_test.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/filter/bloom_test.go new file mode 100644 index 0000000000..1fb56f0713 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/filter/bloom_test.go @@ -0,0 +1,142 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package filter + +import ( + "encoding/binary" + "github.com/syndtr/goleveldb/leveldb/util" + "testing" +) + +type harness struct { + t *testing.T + + bloom Filter + generator FilterGenerator + filter []byte +} + +func newHarness(t *testing.T) *harness { + bloom := NewBloomFilter(10) + return &harness{ + t: t, + bloom: bloom, + generator: bloom.NewGenerator(), + } +} + +func (h *harness) add(key []byte) { + h.generator.Add(key) +} + +func (h *harness) addNum(key uint32) { + var b [4]byte + binary.LittleEndian.PutUint32(b[:], key) + h.add(b[:]) +} + +func (h *harness) build() { + b := &util.Buffer{} + h.generator.Generate(b) + h.filter = b.Bytes() +} + +func (h *harness) reset() { + h.filter = nil +} + +func (h *harness) filterLen() int { + return len(h.filter) +} + +func (h *harness) assert(key []byte, want, silent bool) bool { + got := h.bloom.Contains(h.filter, key) + if !silent && got != want { + h.t.Errorf("assert on '%v' failed got '%v', want '%v'", key, got, want) + } + return got +} + +func (h *harness) assertNum(key uint32, want, silent bool) bool { + var b [4]byte + binary.LittleEndian.PutUint32(b[:], key) + return h.assert(b[:], want, silent) +} + +func TestBloomFilter_Empty(t *testing.T) { + h := newHarness(t) + h.build() + h.assert([]byte("hello"), false, false) + h.assert([]byte("world"), false, false) +} + +func TestBloomFilter_Small(t *testing.T) { + h := newHarness(t) + h.add([]byte("hello")) + h.add([]byte("world")) + h.build() + h.assert([]byte("hello"), true, false) + h.assert([]byte("world"), true, false) + h.assert([]byte("x"), false, false) + h.assert([]byte("foo"), false, false) +} + +func nextN(n int) int { + switch { + case n < 10: + n += 1 + case n < 100: + n += 10 + case n < 1000: + n += 100 + default: + n += 1000 + } + return n +} + +func TestBloomFilter_VaryingLengths(t *testing.T) { + h := newHarness(t) + var mediocre, good int + for n := 1; n < 10000; n = nextN(n) { + h.reset() + for i := 0; i < n; i++ { + h.addNum(uint32(i)) + } + h.build() + + got := h.filterLen() + want := (n * 10 / 8) + 40 + if got > want { + t.Errorf("filter len test failed, '%d' > '%d'", got, want) + } + + for i := 0; i < n; i++ { + h.assertNum(uint32(i), true, false) + } + + var rate float32 + for i := 0; i < 10000; i++ { + if h.assertNum(uint32(i+1000000000), true, true) { + rate++ + } + } + rate /= 10000 + if rate > 0.02 { + t.Errorf("false positive rate is more than 2%%, got %v, at len %d", rate, n) + } + if rate > 0.0125 { + mediocre++ + } else { + good++ + } + } + t.Logf("false positive rate: %d good, %d mediocre", good, mediocre) + if mediocre > good/5 { + 
t.Error("mediocre false positive rate is more than expected") + } +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/filter/filter.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/filter/filter.go new file mode 100644 index 0000000000..7a925c5a86 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/filter/filter.go @@ -0,0 +1,60 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Package filter provides interface and implementation of probabilistic +// data structure. +// +// The filter is resposible for creating small filter from a set of keys. +// These filter will then used to test whether a key is a member of the set. +// In many cases, a filter can cut down the number of disk seeks from a +// handful to a single disk seek per DB.Get call. +package filter + +// Buffer is the interface that wraps basic Alloc, Write and WriteByte methods. +type Buffer interface { + // Alloc allocs n bytes of slice from the buffer. This also advancing + // write offset. + Alloc(n int) []byte + + // Write appends the contents of p to the buffer. + Write(p []byte) (n int, err error) + + // WriteByte appends the byte c to the buffer. + WriteByte(c byte) error +} + +// Filter is the filter. +type Filter interface { + // Name returns the name of this policy. + // + // Note that if the filter encoding changes in an incompatible way, + // the name returned by this method must be changed. Otherwise, old + // incompatible filters may be passed to methods of this type. + Name() string + + // NewGenerator creates a new filter generator. + NewGenerator() FilterGenerator + + // Contains returns true if the filter contains the given key. + // + // The filter are filters generated by the filter generator. + Contains(filter, key []byte) bool +} + +// FilterGenerator is the filter generator. +type FilterGenerator interface { + // Add adds a key to the filter generator. + // + // The key may become invalid after call to this method end, therefor + // key must be copied if implementation require keeping key for later + // use. The key should not modified directly, doing so may cause + // undefined results. + Add(key []byte) + + // Generate generates filters based on keys passed so far. After call + // to Generate the filter generator maybe resetted, depends on implementation. + Generate(b Buffer) +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/go13_bench_test.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/go13_bench_test.go new file mode 100644 index 0000000000..e76657e5ec --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/go13_bench_test.go @@ -0,0 +1,58 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+ +// +build go1.3 + +package leveldb + +import ( + "sync/atomic" + "testing" +) + +func BenchmarkDBReadConcurrent(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.fill() + p.gc() + defer p.close() + + b.ResetTimer() + b.SetBytes(116) + + b.RunParallel(func(pb *testing.PB) { + iter := p.newIter() + defer iter.Release() + for pb.Next() && iter.Next() { + } + }) +} + +func BenchmarkDBReadConcurrent2(b *testing.B) { + p := openDBBench(b, false) + p.populate(b.N) + p.fill() + p.gc() + defer p.close() + + b.ResetTimer() + b.SetBytes(116) + + var dir uint32 + b.RunParallel(func(pb *testing.PB) { + iter := p.newIter() + defer iter.Release() + if atomic.AddUint32(&dir, 1)%2 == 0 { + for pb.Next() && iter.Next() { + } + } else { + if pb.Next() && iter.Last() { + for pb.Next() && iter.Prev() { + } + } + } + }) +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/array_iter.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/array_iter.go new file mode 100644 index 0000000000..a23ab05f70 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/array_iter.go @@ -0,0 +1,184 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package iterator + +import ( + "github.com/syndtr/goleveldb/leveldb/util" +) + +// BasicArray is the interface that wraps basic Len and Search method. +type BasicArray interface { + // Len returns length of the array. + Len() int + + // Search finds smallest index that point to a key that is greater + // than or equal to the given key. + Search(key []byte) int +} + +// Array is the interface that wraps BasicArray and basic Index method. +type Array interface { + BasicArray + + // Index returns key/value pair with index of i. + Index(i int) (key, value []byte) +} + +// Array is the interface that wraps BasicArray and basic Get method. +type ArrayIndexer interface { + BasicArray + + // Get returns a new data iterator with index of i. 
+ Get(i int) Iterator +} + +type basicArrayIterator struct { + util.BasicReleaser + array BasicArray + pos int + err error +} + +func (i *basicArrayIterator) Valid() bool { + return i.pos >= 0 && i.pos < i.array.Len() && !i.Released() +} + +func (i *basicArrayIterator) First() bool { + if i.Released() { + i.err = ErrIterReleased + return false + } + + if i.array.Len() == 0 { + i.pos = -1 + return false + } + i.pos = 0 + return true +} + +func (i *basicArrayIterator) Last() bool { + if i.Released() { + i.err = ErrIterReleased + return false + } + + n := i.array.Len() + if n == 0 { + i.pos = 0 + return false + } + i.pos = n - 1 + return true +} + +func (i *basicArrayIterator) Seek(key []byte) bool { + if i.Released() { + i.err = ErrIterReleased + return false + } + + n := i.array.Len() + if n == 0 { + i.pos = 0 + return false + } + i.pos = i.array.Search(key) + if i.pos >= n { + return false + } + return true +} + +func (i *basicArrayIterator) Next() bool { + if i.Released() { + i.err = ErrIterReleased + return false + } + + i.pos++ + if n := i.array.Len(); i.pos >= n { + i.pos = n + return false + } + return true +} + +func (i *basicArrayIterator) Prev() bool { + if i.Released() { + i.err = ErrIterReleased + return false + } + + i.pos-- + if i.pos < 0 { + i.pos = -1 + return false + } + return true +} + +func (i *basicArrayIterator) Error() error { return i.err } + +type arrayIterator struct { + basicArrayIterator + array Array + pos int + key, value []byte +} + +func (i *arrayIterator) updateKV() { + if i.pos == i.basicArrayIterator.pos { + return + } + i.pos = i.basicArrayIterator.pos + if i.Valid() { + i.key, i.value = i.array.Index(i.pos) + } else { + i.key = nil + i.value = nil + } +} + +func (i *arrayIterator) Key() []byte { + i.updateKV() + return i.key +} + +func (i *arrayIterator) Value() []byte { + i.updateKV() + return i.value +} + +type arrayIteratorIndexer struct { + basicArrayIterator + array ArrayIndexer +} + +func (i *arrayIteratorIndexer) Get() Iterator { + if i.Valid() { + return i.array.Get(i.basicArrayIterator.pos) + } + return nil +} + +// NewArrayIterator returns an iterator from the given array. +func NewArrayIterator(array Array) Iterator { + return &arrayIterator{ + basicArrayIterator: basicArrayIterator{array: array, pos: -1}, + array: array, + pos: -1, + } +} + +// NewArrayIndexer returns an index iterator from the given array. +func NewArrayIndexer(array ArrayIndexer) IteratorIndexer { + return &arrayIteratorIndexer{ + basicArrayIterator: basicArrayIterator{array: array, pos: -1}, + array: array, + } +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/array_iter_test.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/array_iter_test.go new file mode 100644 index 0000000000..1ed6d07cbb --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/array_iter_test.go @@ -0,0 +1,30 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package iterator_test + +import ( + . "github.com/onsi/ginkgo" + + . "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/testutil" +) + +var _ = testutil.Defer(func() { + Describe("Array iterator", func() { + It("Should iterates and seeks correctly", func() { + // Build key/value. + kv := testutil.KeyValue_Generate(nil, 70, 1, 5, 3, 3) + + // Test the iterator. 
+ t := testutil.IteratorTesting{ + KeyValue: kv.Clone(), + Iter: NewArrayIterator(kv), + } + testutil.DoIteratorTesting(&t) + }) + }) +}) diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/indexed_iter.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/indexed_iter.go new file mode 100644 index 0000000000..939adbb933 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/indexed_iter.go @@ -0,0 +1,242 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package iterator + +import ( + "github.com/syndtr/goleveldb/leveldb/errors" + "github.com/syndtr/goleveldb/leveldb/util" +) + +// IteratorIndexer is the interface that wraps CommonIterator and basic Get +// method. IteratorIndexer provides index for indexed iterator. +type IteratorIndexer interface { + CommonIterator + + // Get returns a new data iterator for the current position, or nil if + // done. + Get() Iterator +} + +type indexedIterator struct { + util.BasicReleaser + index IteratorIndexer + strict bool + + data Iterator + err error + errf func(err error) + closed bool +} + +func (i *indexedIterator) setData() { + if i.data != nil { + i.data.Release() + } + i.data = i.index.Get() +} + +func (i *indexedIterator) clearData() { + if i.data != nil { + i.data.Release() + } + i.data = nil +} + +func (i *indexedIterator) indexErr() { + if err := i.index.Error(); err != nil { + if i.errf != nil { + i.errf(err) + } + i.err = err + } +} + +func (i *indexedIterator) dataErr() bool { + if err := i.data.Error(); err != nil { + if i.errf != nil { + i.errf(err) + } + if i.strict || !errors.IsCorrupted(err) { + i.err = err + return true + } + } + return false +} + +func (i *indexedIterator) Valid() bool { + return i.data != nil && i.data.Valid() +} + +func (i *indexedIterator) First() bool { + if i.err != nil { + return false + } else if i.Released() { + i.err = ErrIterReleased + return false + } + + if !i.index.First() { + i.indexErr() + i.clearData() + return false + } + i.setData() + return i.Next() +} + +func (i *indexedIterator) Last() bool { + if i.err != nil { + return false + } else if i.Released() { + i.err = ErrIterReleased + return false + } + + if !i.index.Last() { + i.indexErr() + i.clearData() + return false + } + i.setData() + if !i.data.Last() { + if i.dataErr() { + return false + } + i.clearData() + return i.Prev() + } + return true +} + +func (i *indexedIterator) Seek(key []byte) bool { + if i.err != nil { + return false + } else if i.Released() { + i.err = ErrIterReleased + return false + } + + if !i.index.Seek(key) { + i.indexErr() + i.clearData() + return false + } + i.setData() + if !i.data.Seek(key) { + if i.dataErr() { + return false + } + i.clearData() + return i.Next() + } + return true +} + +func (i *indexedIterator) Next() bool { + if i.err != nil { + return false + } else if i.Released() { + i.err = ErrIterReleased + return false + } + + switch { + case i.data != nil && !i.data.Next(): + if i.dataErr() { + return false + } + i.clearData() + fallthrough + case i.data == nil: + if !i.index.Next() { + i.indexErr() + return false + } + i.setData() + return i.Next() + } + return true +} + +func (i *indexedIterator) Prev() bool { + if i.err != nil { + return false + } else if i.Released() { + i.err = ErrIterReleased + return false + } + + switch { + case i.data != nil && !i.data.Prev(): + if i.dataErr() { + return 
false + } + i.clearData() + fallthrough + case i.data == nil: + if !i.index.Prev() { + i.indexErr() + return false + } + i.setData() + if !i.data.Last() { + if i.dataErr() { + return false + } + i.clearData() + return i.Prev() + } + } + return true +} + +func (i *indexedIterator) Key() []byte { + if i.data == nil { + return nil + } + return i.data.Key() +} + +func (i *indexedIterator) Value() []byte { + if i.data == nil { + return nil + } + return i.data.Value() +} + +func (i *indexedIterator) Release() { + i.clearData() + i.index.Release() + i.BasicReleaser.Release() +} + +func (i *indexedIterator) Error() error { + if i.err != nil { + return i.err + } + if err := i.index.Error(); err != nil { + return err + } + return nil +} + +func (i *indexedIterator) SetErrorCallback(f func(err error)) { + i.errf = f +} + +// NewIndexedIterator returns an 'indexed iterator'. An index is iterator +// that returns another iterator, a 'data iterator'. A 'data iterator' is the +// iterator that contains actual key/value pairs. +// +// If strict is true the any 'corruption errors' (i.e errors.IsCorrupted(err) == true) +// won't be ignored and will halt 'indexed iterator', otherwise the iterator will +// continue to the next 'data iterator'. Corruption on 'index iterator' will not be +// ignored and will halt the iterator. +func NewIndexedIterator(index IteratorIndexer, strict bool) Iterator { + return &indexedIterator{index: index, strict: strict} +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/indexed_iter_test.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/indexed_iter_test.go new file mode 100644 index 0000000000..72a7978924 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/indexed_iter_test.go @@ -0,0 +1,83 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package iterator_test + +import ( + "sort" + + . "github.com/onsi/ginkgo" + + "github.com/syndtr/goleveldb/leveldb/comparer" + . "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/testutil" +) + +type keyValue struct { + key []byte + testutil.KeyValue +} + +type keyValueIndex []keyValue + +func (x keyValueIndex) Search(key []byte) int { + return sort.Search(x.Len(), func(i int) bool { + return comparer.DefaultComparer.Compare(x[i].key, key) >= 0 + }) +} + +func (x keyValueIndex) Len() int { return len(x) } +func (x keyValueIndex) Index(i int) (key, value []byte) { return x[i].key, nil } +func (x keyValueIndex) Get(i int) Iterator { return NewArrayIterator(x[i]) } + +var _ = testutil.Defer(func() { + Describe("Indexed iterator", func() { + Test := func(n ...int) func() { + if len(n) == 0 { + rnd := testutil.NewRand() + n = make([]int, rnd.Intn(17)+3) + for i := range n { + n[i] = rnd.Intn(19) + 1 + } + } + + return func() { + It("Should iterates and seeks correctly", func(done Done) { + // Build key/value. + index := make(keyValueIndex, len(n)) + sum := 0 + for _, x := range n { + sum += x + } + kv := testutil.KeyValue_Generate(nil, sum, 1, 10, 4, 4) + for i, j := 0, 0; i < len(n); i++ { + for x := n[i]; x > 0; x-- { + key, value := kv.Index(j) + index[i].key = key + index[i].Put(key, value) + j++ + } + } + + // Test the iterator. 
+				t := testutil.IteratorTesting{
+					KeyValue: kv.Clone(),
+					Iter:     NewIndexedIterator(NewArrayIndexer(index), true),
+				}
+				testutil.DoIteratorTesting(&t)
+				done <- true
+			}, 1.5)
+		}
+	}
+
+	Describe("with 100 keys", Test(100))
+	Describe("with 50-50 keys", Test(50, 50))
+	Describe("with 50-1 keys", Test(50, 1))
+	Describe("with 50-1-50 keys", Test(50, 1, 50))
+	Describe("with 1-50 keys", Test(1, 50))
+	Describe("with random N-keys", Test())
+	})
+})
diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/iter.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/iter.go
new file mode 100644
index 0000000000..c2522860b0
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/iter.go
@@ -0,0 +1,131 @@
+// Copyright (c) 2012, Suryandaru Triandana
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+// Package iterator provides an interface and implementations to traverse over
+// the contents of a database.
+package iterator
+
+import (
+	"errors"
+
+	"github.com/syndtr/goleveldb/leveldb/util"
+)
+
+var (
+	ErrIterReleased = errors.New("leveldb/iterator: iterator released")
+)
+
+// IteratorSeeker is the interface that wraps the 'seeks method'.
+type IteratorSeeker interface {
+	// First moves the iterator to the first key/value pair. If the iterator
+	// only contains one key/value pair then First and Last would move
+	// to the same key/value pair.
+	// It returns whether such a pair exists.
+	First() bool
+
+	// Last moves the iterator to the last key/value pair. If the iterator
+	// only contains one key/value pair then First and Last would move
+	// to the same key/value pair.
+	// It returns whether such a pair exists.
+	Last() bool
+
+	// Seek moves the iterator to the first key/value pair whose key is greater
+	// than or equal to the given key.
+	// It returns whether such a pair exists.
+	//
+	// It is safe to modify the contents of the argument after Seek returns.
+	Seek(key []byte) bool
+
+	// Next moves the iterator to the next key/value pair.
+	// It returns whether the iterator is exhausted.
+	Next() bool
+
+	// Prev moves the iterator to the previous key/value pair.
+	// It returns whether the iterator is exhausted.
+	Prev() bool
+}
+
+// CommonIterator is the interface that wraps common iterator methods.
+type CommonIterator interface {
+	IteratorSeeker
+
+	// util.Releaser is the interface that wraps basic Release method.
+	// When called, Release releases any resources associated with the
+	// iterator.
+	util.Releaser
+
+	// util.ReleaseSetter is the interface that wraps the basic SetReleaser
+	// method.
+	util.ReleaseSetter
+
+	// TODO: Remove this when ready.
+	Valid() bool
+
+	// Error returns any accumulated error. Exhausting all the key/value pairs
+	// is not considered to be an error.
+	Error() error
+}
+
+// Iterator iterates over a DB's key/value pairs in key order.
+//
+// When an error is encountered, any 'seeks method' will return false and will
+// yield no key/value pairs. The error can be queried by calling the Error
+// method. Calling Release is still necessary.
+//
+// An iterator must be released after use, but it is not necessary to read
+// an iterator until exhaustion.
+// Also, an iterator is not necessarily goroutine-safe, but it is safe to use
+// multiple iterators concurrently, with each in a dedicated goroutine.
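+//
+// A typical caller loop, as sketched in the leveldb package documentation
+// (db is a *leveldb.DB):
+//
+//	iter := db.NewIterator(nil, nil)
+//	for iter.Next() {
+//		// Use iter.Key() and iter.Value().
+//	}
+//	iter.Release()
+//	err := iter.Error()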
+type Iterator interface {
+	CommonIterator
+
+	// Key returns the key of the current key/value pair, or nil if done.
+	// The caller should not modify the contents of the returned slice, and
+	// its contents may change on the next call to any 'seeks method'.
+	Key() []byte
+
+	// Value returns the value of the current key/value pair, or nil if done.
+	// The caller should not modify the contents of the returned slice, and
+	// its contents may change on the next call to any 'seeks method'.
+	Value() []byte
+}
+
+// ErrorCallbackSetter is the interface that wraps the basic SetErrorCallback
+// method.
+//
+// ErrorCallbackSetter is implemented by the indexed and merged iterators.
+type ErrorCallbackSetter interface {
+	// SetErrorCallback allows setting an error callback for the corresponding
+	// iterator. Use nil to clear the callback.
+	SetErrorCallback(f func(err error))
+}
+
+type emptyIterator struct {
+	util.BasicReleaser
+	err error
+}
+
+func (i *emptyIterator) rErr() {
+	if i.err == nil && i.Released() {
+		i.err = ErrIterReleased
+	}
+}
+
+func (*emptyIterator) Valid() bool            { return false }
+func (i *emptyIterator) First() bool          { i.rErr(); return false }
+func (i *emptyIterator) Last() bool           { i.rErr(); return false }
+func (i *emptyIterator) Seek(key []byte) bool { i.rErr(); return false }
+func (i *emptyIterator) Next() bool           { i.rErr(); return false }
+func (i *emptyIterator) Prev() bool           { i.rErr(); return false }
+func (*emptyIterator) Key() []byte            { return nil }
+func (*emptyIterator) Value() []byte          { return nil }
+func (i *emptyIterator) Error() error         { return i.err }
+
+// NewEmptyIterator creates an empty iterator. The err parameter can be
+// nil, but if not nil the given err will be returned by the Error method.
+func NewEmptyIterator(err error) Iterator {
+	return &emptyIterator{err: err}
+}
diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/iter_suite_test.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/iter_suite_test.go
new file mode 100644
index 0000000000..5ef8d5bafb
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/iter_suite_test.go
@@ -0,0 +1,11 @@
+package iterator_test
+
+import (
+	"testing"
+
+	"github.com/syndtr/goleveldb/leveldb/testutil"
+)
+
+func TestIterator(t *testing.T) {
+	testutil.RunSuite(t, "Iterator Suite")
+}
diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/merged_iter.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/merged_iter.go
new file mode 100644
index 0000000000..1a7e29df8f
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/merged_iter.go
@@ -0,0 +1,304 @@
+// Copyright (c) 2012, Suryandaru Triandana
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+ +package iterator + +import ( + "github.com/syndtr/goleveldb/leveldb/comparer" + "github.com/syndtr/goleveldb/leveldb/errors" + "github.com/syndtr/goleveldb/leveldb/util" +) + +type dir int + +const ( + dirReleased dir = iota - 1 + dirSOI + dirEOI + dirBackward + dirForward +) + +type mergedIterator struct { + cmp comparer.Comparer + iters []Iterator + strict bool + + keys [][]byte + index int + dir dir + err error + errf func(err error) + releaser util.Releaser +} + +func assertKey(key []byte) []byte { + if key == nil { + panic("leveldb/iterator: nil key") + } + return key +} + +func (i *mergedIterator) iterErr(iter Iterator) bool { + if err := iter.Error(); err != nil { + if i.errf != nil { + i.errf(err) + } + if i.strict || !errors.IsCorrupted(err) { + i.err = err + return true + } + } + return false +} + +func (i *mergedIterator) Valid() bool { + return i.err == nil && i.dir > dirEOI +} + +func (i *mergedIterator) First() bool { + if i.err != nil { + return false + } else if i.dir == dirReleased { + i.err = ErrIterReleased + return false + } + + for x, iter := range i.iters { + switch { + case iter.First(): + i.keys[x] = assertKey(iter.Key()) + case i.iterErr(iter): + return false + default: + i.keys[x] = nil + } + } + i.dir = dirSOI + return i.next() +} + +func (i *mergedIterator) Last() bool { + if i.err != nil { + return false + } else if i.dir == dirReleased { + i.err = ErrIterReleased + return false + } + + for x, iter := range i.iters { + switch { + case iter.Last(): + i.keys[x] = assertKey(iter.Key()) + case i.iterErr(iter): + return false + default: + i.keys[x] = nil + } + } + i.dir = dirEOI + return i.prev() +} + +func (i *mergedIterator) Seek(key []byte) bool { + if i.err != nil { + return false + } else if i.dir == dirReleased { + i.err = ErrIterReleased + return false + } + + for x, iter := range i.iters { + switch { + case iter.Seek(key): + i.keys[x] = assertKey(iter.Key()) + case i.iterErr(iter): + return false + default: + i.keys[x] = nil + } + } + i.dir = dirSOI + return i.next() +} + +func (i *mergedIterator) next() bool { + var key []byte + if i.dir == dirForward { + key = i.keys[i.index] + } + for x, tkey := range i.keys { + if tkey != nil && (key == nil || i.cmp.Compare(tkey, key) < 0) { + key = tkey + i.index = x + } + } + if key == nil { + i.dir = dirEOI + return false + } + i.dir = dirForward + return true +} + +func (i *mergedIterator) Next() bool { + if i.dir == dirEOI || i.err != nil { + return false + } else if i.dir == dirReleased { + i.err = ErrIterReleased + return false + } + + switch i.dir { + case dirSOI: + return i.First() + case dirBackward: + key := append([]byte{}, i.keys[i.index]...) 
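+		// Direction change (backward to forward): Seek re-positions every
+		// child iterator at the first key >= the current key, i.e. the
+		// current key itself, and the Next call below then steps past it.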
+ if !i.Seek(key) { + return false + } + return i.Next() + } + + x := i.index + iter := i.iters[x] + switch { + case iter.Next(): + i.keys[x] = assertKey(iter.Key()) + case i.iterErr(iter): + return false + default: + i.keys[x] = nil + } + return i.next() +} + +func (i *mergedIterator) prev() bool { + var key []byte + if i.dir == dirBackward { + key = i.keys[i.index] + } + for x, tkey := range i.keys { + if tkey != nil && (key == nil || i.cmp.Compare(tkey, key) > 0) { + key = tkey + i.index = x + } + } + if key == nil { + i.dir = dirSOI + return false + } + i.dir = dirBackward + return true +} + +func (i *mergedIterator) Prev() bool { + if i.dir == dirSOI || i.err != nil { + return false + } else if i.dir == dirReleased { + i.err = ErrIterReleased + return false + } + + switch i.dir { + case dirEOI: + return i.Last() + case dirForward: + key := append([]byte{}, i.keys[i.index]...) + for x, iter := range i.iters { + if x == i.index { + continue + } + seek := iter.Seek(key) + switch { + case seek && iter.Prev(), !seek && iter.Last(): + i.keys[x] = assertKey(iter.Key()) + case i.iterErr(iter): + return false + default: + i.keys[x] = nil + } + } + } + + x := i.index + iter := i.iters[x] + switch { + case iter.Prev(): + i.keys[x] = assertKey(iter.Key()) + case i.iterErr(iter): + return false + default: + i.keys[x] = nil + } + return i.prev() +} + +func (i *mergedIterator) Key() []byte { + if i.err != nil || i.dir <= dirEOI { + return nil + } + return i.keys[i.index] +} + +func (i *mergedIterator) Value() []byte { + if i.err != nil || i.dir <= dirEOI { + return nil + } + return i.iters[i.index].Value() +} + +func (i *mergedIterator) Release() { + if i.dir != dirReleased { + i.dir = dirReleased + for _, iter := range i.iters { + iter.Release() + } + i.iters = nil + i.keys = nil + if i.releaser != nil { + i.releaser.Release() + i.releaser = nil + } + } +} + +func (i *mergedIterator) SetReleaser(releaser util.Releaser) { + if i.dir == dirReleased { + panic(util.ErrReleased) + } + if i.releaser != nil && releaser != nil { + panic(util.ErrHasReleaser) + } + i.releaser = releaser +} + +func (i *mergedIterator) Error() error { + return i.err +} + +func (i *mergedIterator) SetErrorCallback(f func(err error)) { + i.errf = f +} + +// NewMergedIterator returns an iterator that merges its input. Walking the +// resultant iterator will return all key/value pairs of all input iterators +// in strictly increasing key order, as defined by cmp. +// The input's key ranges may overlap, but there are assumed to be no duplicate +// keys: if iters[i] contains a key k then iters[j] will not contain that key k. +// None of the iters may be nil. +// +// If strict is true the any 'corruption errors' (i.e errors.IsCorrupted(err) == true) +// won't be ignored and will halt 'merged iterator', otherwise the iterator will +// continue to the next 'input iterator'. +func NewMergedIterator(iters []Iterator, cmp comparer.Comparer, strict bool) Iterator { + return &mergedIterator{ + iters: iters, + cmp: cmp, + strict: strict, + keys: make([][]byte, len(iters)), + } +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/merged_iter_test.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/merged_iter_test.go new file mode 100644 index 0000000000..e523b63e4b --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/iterator/merged_iter_test.go @@ -0,0 +1,60 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. 
+// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package iterator_test + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "github.com/syndtr/goleveldb/leveldb/comparer" + . "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/testutil" +) + +var _ = testutil.Defer(func() { + Describe("Merged iterator", func() { + Test := func(filled int, empty int) func() { + return func() { + It("Should iterates and seeks correctly", func(done Done) { + rnd := testutil.NewRand() + + // Build key/value. + filledKV := make([]testutil.KeyValue, filled) + kv := testutil.KeyValue_Generate(nil, 100, 1, 10, 4, 4) + kv.Iterate(func(i int, key, value []byte) { + filledKV[rnd.Intn(filled)].Put(key, value) + }) + + // Create itearators. + iters := make([]Iterator, filled+empty) + for i := range iters { + if empty == 0 || (rnd.Int()%2 == 0 && filled > 0) { + filled-- + Expect(filledKV[filled].Len()).ShouldNot(BeZero()) + iters[i] = NewArrayIterator(filledKV[filled]) + } else { + empty-- + iters[i] = NewEmptyIterator(nil) + } + } + + // Test the iterator. + t := testutil.IteratorTesting{ + KeyValue: kv.Clone(), + Iter: NewMergedIterator(iters, comparer.DefaultComparer, true), + } + testutil.DoIteratorTesting(&t) + done <- true + }, 1.5) + } + } + + Describe("with three, all filled iterators", Test(3, 0)) + Describe("with one filled, one empty iterators", Test(1, 1)) + Describe("with one filled, two empty iterators", Test(1, 2)) + }) +}) diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/journal/journal.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/journal/journal.go new file mode 100644 index 0000000000..6519ec660e --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/journal/journal.go @@ -0,0 +1,520 @@ +// Copyright 2011 The LevelDB-Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Taken from: https://code.google.com/p/leveldb-go/source/browse/leveldb/record/record.go?r=1d5ccbe03246da926391ee12d1c6caae054ff4b0 +// License, authors and contributors informations can be found at bellow URLs respectively: +// https://code.google.com/p/leveldb-go/source/browse/LICENSE +// https://code.google.com/p/leveldb-go/source/browse/AUTHORS +// https://code.google.com/p/leveldb-go/source/browse/CONTRIBUTORS + +// Package journal reads and writes sequences of journals. Each journal is a stream +// of bytes that completes before the next journal starts. +// +// When reading, call Next to obtain an io.Reader for the next journal. Next will +// return io.EOF when there are no more journals. It is valid to call Next +// without reading the current journal to exhaustion. +// +// When writing, call Next to obtain an io.Writer for the next journal. Calling +// Next finishes the current journal. Call Close to finish the final journal. +// +// Optionally, call Flush to finish the current journal and flush the underlying +// writer without starting a new journal. To start a new journal after flushing, +// call Next. +// +// Neither Readers or Writers are safe to use concurrently. 
+//
+// Example code:
+//	func read(r io.Reader) ([]string, error) {
+//		var ss []string
+//		journals := journal.NewReader(r, nil, true, true)
+//		for {
+//			j, err := journals.Next()
+//			if err == io.EOF {
+//				break
+//			}
+//			if err != nil {
+//				return nil, err
+//			}
+//			s, err := ioutil.ReadAll(j)
+//			if err != nil {
+//				return nil, err
+//			}
+//			ss = append(ss, string(s))
+//		}
+//		return ss, nil
+//	}
+//
+//	func write(w io.Writer, ss []string) error {
+//		journals := journal.NewWriter(w)
+//		for _, s := range ss {
+//			j, err := journals.Next()
+//			if err != nil {
+//				return err
+//			}
+//			if _, err := j.Write([]byte(s)); err != nil {
+//				return err
+//			}
+//		}
+//		return journals.Close()
+//	}
+//
+// The wire format is that the stream is divided into 32KiB blocks, and each
+// block contains a number of tightly packed chunks. Chunks cannot cross block
+// boundaries. The last block may be shorter than 32 KiB. Any unused bytes in a
+// block must be zero.
+//
+// A journal maps to one or more chunks. Each chunk has a 7 byte header (a 4
+// byte checksum, a 2 byte little-endian uint16 length, and a 1 byte chunk type)
+// followed by a payload. The checksum is over the chunk type and the payload.
+//
+// There are four chunk types: whether the chunk is the full journal, or the
+// first, middle or last chunk of a multi-chunk journal. A multi-chunk journal
+// has one first chunk, zero or more middle chunks, and one last chunk.
+//
+// The wire format allows for limited recovery in the face of data corruption:
+// on a format error (such as a checksum mismatch), the reader moves to the
+// next block and looks for the next full or first chunk.
+package journal
+
+import (
+	"encoding/binary"
+	"fmt"
+	"io"
+
+	"github.com/syndtr/goleveldb/leveldb/errors"
+	"github.com/syndtr/goleveldb/leveldb/util"
+)
+
+// These constants are part of the wire format and should not be changed.
+const (
+	fullChunkType   = 1
+	firstChunkType  = 2
+	middleChunkType = 3
+	lastChunkType   = 4
+)
+
+const (
+	blockSize  = 32 * 1024
+	headerSize = 7
+)
+
+type flusher interface {
+	Flush() error
+}
+
+// ErrCorrupted is the error type that is generated by a corrupted block or chunk.
+type ErrCorrupted struct {
+	Size   int
+	Reason string
+}
+
+func (e *ErrCorrupted) Error() string {
+	return fmt.Sprintf("leveldb/journal: block/chunk corrupted: %s (%d bytes)", e.Reason, e.Size)
+}
+
+// Dropper is the interface that wraps the simple Drop method. The Drop
+// method will be called when the journal reader is dropping a block or chunk.
+type Dropper interface {
+	Drop(err error)
+}
+
+// Reader reads journals from an underlying io.Reader.
+type Reader struct {
+	// r is the underlying reader.
+	r io.Reader
+	// the dropper.
+	dropper Dropper
+	// strict flag.
+	strict bool
+	// checksum flag.
+	checksum bool
+	// seq is the sequence number of the current journal.
+	seq int
+	// buf[i:j] is the unread portion of the current chunk's payload.
+	// The low bound, i, excludes the chunk header.
+	i, j int
+	// n is the number of bytes of buf that are valid. Once reading has started,
+	// only the final block can have n < blockSize.
+	n int
+	// last is whether the current chunk is the last chunk of the journal.
+	last bool
+	// err is any accumulated error.
+	err error
+	// buf is the buffer.
+	buf [blockSize]byte
+}
+
+// NewReader returns a new reader. The dropper may be nil, and if
+// strict is true then a corrupted or invalid chunk will halt the journal
+// reader entirely.
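+//
+// For example, mirroring the package example above, a strict reader with
+// checksum verification and no dropper (f is any io.Reader):
+//
+//	r := journal.NewReader(f, nil, true, true)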
+func NewReader(r io.Reader, dropper Dropper, strict, checksum bool) *Reader { + return &Reader{ + r: r, + dropper: dropper, + strict: strict, + checksum: checksum, + last: true, + } +} + +var errSkip = errors.New("leveldb/journal: skipped") + +func (r *Reader) corrupt(n int, reason string, skip bool) error { + if r.dropper != nil { + r.dropper.Drop(&ErrCorrupted{n, reason}) + } + if r.strict && !skip { + r.err = errors.NewErrCorrupted(nil, &ErrCorrupted{n, reason}) + return r.err + } + return errSkip +} + +// nextChunk sets r.buf[r.i:r.j] to hold the next chunk's payload, reading the +// next block into the buffer if necessary. +func (r *Reader) nextChunk(first bool) error { + for { + if r.j+headerSize <= r.n { + checksum := binary.LittleEndian.Uint32(r.buf[r.j+0 : r.j+4]) + length := binary.LittleEndian.Uint16(r.buf[r.j+4 : r.j+6]) + chunkType := r.buf[r.j+6] + + if checksum == 0 && length == 0 && chunkType == 0 { + // Drop entire block. + m := r.n - r.j + r.i = r.n + r.j = r.n + return r.corrupt(m, "zero header", false) + } else { + m := r.n - r.j + r.i = r.j + headerSize + r.j = r.j + headerSize + int(length) + if r.j > r.n { + // Drop entire block. + r.i = r.n + r.j = r.n + return r.corrupt(m, "chunk length overflows block", false) + } else if r.checksum && checksum != util.NewCRC(r.buf[r.i-1:r.j]).Value() { + // Drop entire block. + r.i = r.n + r.j = r.n + return r.corrupt(m, "checksum mismatch", false) + } + } + if first && chunkType != fullChunkType && chunkType != firstChunkType { + m := r.j - r.i + r.i = r.j + // Report the error, but skip it. + return r.corrupt(m+headerSize, "orphan chunk", true) + } + r.last = chunkType == fullChunkType || chunkType == lastChunkType + return nil + } + + // The last block. + if r.n < blockSize && r.n > 0 { + if !first { + return r.corrupt(0, "missing chunk part", false) + } + r.err = io.EOF + return r.err + } + + // Read block. + n, err := io.ReadFull(r.r, r.buf[:]) + if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { + return err + } + if n == 0 { + if !first { + return r.corrupt(0, "missing chunk part", false) + } + r.err = io.EOF + return r.err + } + r.i, r.j, r.n = 0, 0, n + } +} + +// Next returns a reader for the next journal. It returns io.EOF if there are no +// more journals. The reader returned becomes stale after the next Next call, +// and should no longer be used. If strict is false, the reader will returns +// io.ErrUnexpectedEOF error when found corrupted journal. +func (r *Reader) Next() (io.Reader, error) { + r.seq++ + if r.err != nil { + return nil, r.err + } + r.i = r.j + for { + if err := r.nextChunk(true); err == nil { + break + } else if err != errSkip { + return nil, err + } + } + return &singleReader{r, r.seq, nil}, nil +} + +// Reset resets the journal reader, allows reuse of the journal reader. Reset returns +// last accumulated error. 
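+// Reset also invalidates any reader previously returned by Next: its
+// sequence number no longer matches, so further reads from it fail with a
+// 'stale reader' error.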
+func (r *Reader) Reset(reader io.Reader, dropper Dropper, strict, checksum bool) error { + r.seq++ + err := r.err + r.r = reader + r.dropper = dropper + r.strict = strict + r.checksum = checksum + r.i = 0 + r.j = 0 + r.n = 0 + r.last = true + r.err = nil + return err +} + +type singleReader struct { + r *Reader + seq int + err error +} + +func (x *singleReader) Read(p []byte) (int, error) { + r := x.r + if r.seq != x.seq { + return 0, errors.New("leveldb/journal: stale reader") + } + if x.err != nil { + return 0, x.err + } + if r.err != nil { + return 0, r.err + } + for r.i == r.j { + if r.last { + return 0, io.EOF + } + x.err = r.nextChunk(false) + if x.err != nil { + if x.err == errSkip { + x.err = io.ErrUnexpectedEOF + } + return 0, x.err + } + } + n := copy(p, r.buf[r.i:r.j]) + r.i += n + return n, nil +} + +func (x *singleReader) ReadByte() (byte, error) { + r := x.r + if r.seq != x.seq { + return 0, errors.New("leveldb/journal: stale reader") + } + if x.err != nil { + return 0, x.err + } + if r.err != nil { + return 0, r.err + } + for r.i == r.j { + if r.last { + return 0, io.EOF + } + x.err = r.nextChunk(false) + if x.err != nil { + if x.err == errSkip { + x.err = io.ErrUnexpectedEOF + } + return 0, x.err + } + } + c := r.buf[r.i] + r.i++ + return c, nil +} + +// Writer writes journals to an underlying io.Writer. +type Writer struct { + // w is the underlying writer. + w io.Writer + // seq is the sequence number of the current journal. + seq int + // f is w as a flusher. + f flusher + // buf[i:j] is the bytes that will become the current chunk. + // The low bound, i, includes the chunk header. + i, j int + // buf[:written] has already been written to w. + // written is zero unless Flush has been called. + written int + // first is whether the current chunk is the first chunk of the journal. + first bool + // pending is whether a chunk is buffered but not yet written. + pending bool + // err is any accumulated error. + err error + // buf is the buffer. + buf [blockSize]byte +} + +// NewWriter returns a new Writer. +func NewWriter(w io.Writer) *Writer { + f, _ := w.(flusher) + return &Writer{ + w: w, + f: f, + } +} + +// fillHeader fills in the header for the pending chunk. +func (w *Writer) fillHeader(last bool) { + if w.i+headerSize > w.j || w.j > blockSize { + panic("leveldb/journal: bad writer state") + } + if last { + if w.first { + w.buf[w.i+6] = fullChunkType + } else { + w.buf[w.i+6] = lastChunkType + } + } else { + if w.first { + w.buf[w.i+6] = firstChunkType + } else { + w.buf[w.i+6] = middleChunkType + } + } + binary.LittleEndian.PutUint32(w.buf[w.i+0:w.i+4], util.NewCRC(w.buf[w.i+6:w.j]).Value()) + binary.LittleEndian.PutUint16(w.buf[w.i+4:w.i+6], uint16(w.j-w.i-headerSize)) +} + +// writeBlock writes the buffered block to the underlying writer, and reserves +// space for the next chunk's header. +func (w *Writer) writeBlock() { + _, w.err = w.w.Write(w.buf[w.written:]) + w.i = 0 + w.j = headerSize + w.written = 0 +} + +// writePending finishes the current journal and writes the buffer to the +// underlying writer. +func (w *Writer) writePending() { + if w.err != nil { + return + } + if w.pending { + w.fillHeader(true) + w.pending = false + } + _, w.err = w.w.Write(w.buf[w.written:w.j]) + w.written = w.j +} + +// Close finishes the current journal and closes the writer. 
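+// Subsequent calls to Next or Flush report a 'closed Writer' error; calling
+// Reset makes the Writer usable again.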
+func (w *Writer) Close() error { + w.seq++ + w.writePending() + if w.err != nil { + return w.err + } + w.err = errors.New("leveldb/journal: closed Writer") + return nil +} + +// Flush finishes the current journal, writes to the underlying writer, and +// flushes it if that writer implements interface{ Flush() error }. +func (w *Writer) Flush() error { + w.seq++ + w.writePending() + if w.err != nil { + return w.err + } + if w.f != nil { + w.err = w.f.Flush() + return w.err + } + return nil +} + +// Reset resets the journal writer, allows reuse of the journal writer. Reset +// will also closes the journal writer if not already. +func (w *Writer) Reset(writer io.Writer) (err error) { + w.seq++ + if w.err == nil { + w.writePending() + err = w.err + } + w.w = writer + w.f, _ = writer.(flusher) + w.i = 0 + w.j = 0 + w.written = 0 + w.first = false + w.pending = false + w.err = nil + return +} + +// Next returns a writer for the next journal. The writer returned becomes stale +// after the next Close, Flush or Next call, and should no longer be used. +func (w *Writer) Next() (io.Writer, error) { + w.seq++ + if w.err != nil { + return nil, w.err + } + if w.pending { + w.fillHeader(true) + } + w.i = w.j + w.j = w.j + headerSize + // Check if there is room in the block for the header. + if w.j > blockSize { + // Fill in the rest of the block with zeroes. + for k := w.i; k < blockSize; k++ { + w.buf[k] = 0 + } + w.writeBlock() + if w.err != nil { + return nil, w.err + } + } + w.first = true + w.pending = true + return singleWriter{w, w.seq}, nil +} + +type singleWriter struct { + w *Writer + seq int +} + +func (x singleWriter) Write(p []byte) (int, error) { + w := x.w + if w.seq != x.seq { + return 0, errors.New("leveldb/journal: stale writer") + } + if w.err != nil { + return 0, w.err + } + n0 := len(p) + for len(p) > 0 { + // Write a block, if it is full. + if w.j == blockSize { + w.fillHeader(false) + w.writeBlock() + if w.err != nil { + return 0, w.err + } + w.first = false + } + // Copy bytes into the buffer. + n := copy(w.buf[w.j:], p) + w.j += n + p = p[n:] + } + return n0, nil +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/journal/journal_test.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/journal/journal_test.go new file mode 100644 index 0000000000..0fcf22599f --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/journal/journal_test.go @@ -0,0 +1,818 @@ +// Copyright 2011 The LevelDB-Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Taken from: https://code.google.com/p/leveldb-go/source/browse/leveldb/record/record_test.go?r=df1fa28f7f3be6c3935548169002309c12967135 +// License, authors and contributors informations can be found at bellow URLs respectively: +// https://code.google.com/p/leveldb-go/source/browse/LICENSE +// https://code.google.com/p/leveldb-go/source/browse/AUTHORS +// https://code.google.com/p/leveldb-go/source/browse/CONTRIBUTORS + +package journal + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "io/ioutil" + "math/rand" + "strings" + "testing" +) + +type dropper struct { + t *testing.T +} + +func (d dropper) Drop(err error) { + d.t.Log(err) +} + +func short(s string) string { + if len(s) < 64 { + return s + } + return fmt.Sprintf("%s...(skipping %d bytes)...%s", s[:20], len(s)-40, s[len(s)-20:]) +} + +// big returns a string of length n, composed of repetitions of partial. 
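+// For example, big("abc", 7) == "abcabca".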
+func big(partial string, n int) string { + return strings.Repeat(partial, n/len(partial)+1)[:n] +} + +func TestEmpty(t *testing.T) { + buf := new(bytes.Buffer) + r := NewReader(buf, dropper{t}, true, true) + if _, err := r.Next(); err != io.EOF { + t.Fatalf("got %v, want %v", err, io.EOF) + } +} + +func testGenerator(t *testing.T, reset func(), gen func() (string, bool)) { + buf := new(bytes.Buffer) + + reset() + w := NewWriter(buf) + for { + s, ok := gen() + if !ok { + break + } + ww, err := w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write([]byte(s)); err != nil { + t.Fatal(err) + } + } + if err := w.Close(); err != nil { + t.Fatal(err) + } + + reset() + r := NewReader(buf, dropper{t}, true, true) + for { + s, ok := gen() + if !ok { + break + } + rr, err := r.Next() + if err != nil { + t.Fatal(err) + } + x, err := ioutil.ReadAll(rr) + if err != nil { + t.Fatal(err) + } + if string(x) != s { + t.Fatalf("got %q, want %q", short(string(x)), short(s)) + } + } + if _, err := r.Next(); err != io.EOF { + t.Fatalf("got %v, want %v", err, io.EOF) + } +} + +func testLiterals(t *testing.T, s []string) { + var i int + reset := func() { + i = 0 + } + gen := func() (string, bool) { + if i == len(s) { + return "", false + } + i++ + return s[i-1], true + } + testGenerator(t, reset, gen) +} + +func TestMany(t *testing.T) { + const n = 1e5 + var i int + reset := func() { + i = 0 + } + gen := func() (string, bool) { + if i == n { + return "", false + } + i++ + return fmt.Sprintf("%d.", i-1), true + } + testGenerator(t, reset, gen) +} + +func TestRandom(t *testing.T) { + const n = 1e2 + var ( + i int + r *rand.Rand + ) + reset := func() { + i, r = 0, rand.New(rand.NewSource(0)) + } + gen := func() (string, bool) { + if i == n { + return "", false + } + i++ + return strings.Repeat(string(uint8(i)), r.Intn(2*blockSize+16)), true + } + testGenerator(t, reset, gen) +} + +func TestBasic(t *testing.T) { + testLiterals(t, []string{ + strings.Repeat("a", 1000), + strings.Repeat("b", 97270), + strings.Repeat("c", 8000), + }) +} + +func TestBoundary(t *testing.T) { + for i := blockSize - 16; i < blockSize+16; i++ { + s0 := big("abcd", i) + for j := blockSize - 16; j < blockSize+16; j++ { + s1 := big("ABCDE", j) + testLiterals(t, []string{s0, s1}) + testLiterals(t, []string{s0, "", s1}) + testLiterals(t, []string{s0, "x", s1}) + } + } +} + +func TestFlush(t *testing.T) { + buf := new(bytes.Buffer) + w := NewWriter(buf) + // Write a couple of records. Everything should still be held + // in the record.Writer buffer, so that buf.Len should be 0. + w0, _ := w.Next() + w0.Write([]byte("0")) + w1, _ := w.Next() + w1.Write([]byte("11")) + if got, want := buf.Len(), 0; got != want { + t.Fatalf("buffer length #0: got %d want %d", got, want) + } + // Flush the record.Writer buffer, which should yield 17 bytes. + // 17 = 2*7 + 1 + 2, which is two headers and 1 + 2 payload bytes. + if err := w.Flush(); err != nil { + t.Fatal(err) + } + if got, want := buf.Len(), 17; got != want { + t.Fatalf("buffer length #1: got %d want %d", got, want) + } + // Do another write, one that isn't large enough to complete the block. + // The write should not have flowed through to buf. + w2, _ := w.Next() + w2.Write(bytes.Repeat([]byte("2"), 10000)) + if got, want := buf.Len(), 17; got != want { + t.Fatalf("buffer length #2: got %d want %d", got, want) + } + // Flushing should get us up to 10024 bytes written. + // 10024 = 17 + 7 + 10000. 
+ if err := w.Flush(); err != nil { + t.Fatal(err) + } + if got, want := buf.Len(), 10024; got != want { + t.Fatalf("buffer length #3: got %d want %d", got, want) + } + // Do a bigger write, one that completes the current block. + // We should now have 32768 bytes (a complete block), without + // an explicit flush. + w3, _ := w.Next() + w3.Write(bytes.Repeat([]byte("3"), 40000)) + if got, want := buf.Len(), 32768; got != want { + t.Fatalf("buffer length #4: got %d want %d", got, want) + } + // Flushing should get us up to 50038 bytes written. + // 50038 = 10024 + 2*7 + 40000. There are two headers because + // the one record was split into two chunks. + if err := w.Flush(); err != nil { + t.Fatal(err) + } + if got, want := buf.Len(), 50038; got != want { + t.Fatalf("buffer length #5: got %d want %d", got, want) + } + // Check that reading those records give the right lengths. + r := NewReader(buf, dropper{t}, true, true) + wants := []int64{1, 2, 10000, 40000} + for i, want := range wants { + rr, _ := r.Next() + n, err := io.Copy(ioutil.Discard, rr) + if err != nil { + t.Fatalf("read #%d: %v", i, err) + } + if n != want { + t.Fatalf("read #%d: got %d bytes want %d", i, n, want) + } + } +} + +func TestNonExhaustiveRead(t *testing.T) { + const n = 100 + buf := new(bytes.Buffer) + p := make([]byte, 10) + rnd := rand.New(rand.NewSource(1)) + + w := NewWriter(buf) + for i := 0; i < n; i++ { + length := len(p) + rnd.Intn(3*blockSize) + s := string(uint8(i)) + "123456789abcdefgh" + ww, _ := w.Next() + ww.Write([]byte(big(s, length))) + } + if err := w.Close(); err != nil { + t.Fatal(err) + } + + r := NewReader(buf, dropper{t}, true, true) + for i := 0; i < n; i++ { + rr, _ := r.Next() + _, err := io.ReadFull(rr, p) + if err != nil { + t.Fatal(err) + } + want := string(uint8(i)) + "123456789" + if got := string(p); got != want { + t.Fatalf("read #%d: got %q want %q", i, got, want) + } + } +} + +func TestStaleReader(t *testing.T) { + buf := new(bytes.Buffer) + + w := NewWriter(buf) + w0, err := w.Next() + if err != nil { + t.Fatal(err) + } + w0.Write([]byte("0")) + w1, err := w.Next() + if err != nil { + t.Fatal(err) + } + w1.Write([]byte("11")) + if err := w.Close(); err != nil { + t.Fatal(err) + } + + r := NewReader(buf, dropper{t}, true, true) + r0, err := r.Next() + if err != nil { + t.Fatal(err) + } + r1, err := r.Next() + if err != nil { + t.Fatal(err) + } + p := make([]byte, 1) + if _, err := r0.Read(p); err == nil || !strings.Contains(err.Error(), "stale") { + t.Fatalf("stale read #0: unexpected error: %v", err) + } + if _, err := r1.Read(p); err != nil { + t.Fatalf("fresh read #1: got %v want nil error", err) + } + if p[0] != '1' { + t.Fatalf("fresh read #1: byte contents: got '%c' want '1'", p[0]) + } +} + +func TestStaleWriter(t *testing.T) { + buf := new(bytes.Buffer) + + w := NewWriter(buf) + w0, err := w.Next() + if err != nil { + t.Fatal(err) + } + w1, err := w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := w0.Write([]byte("0")); err == nil || !strings.Contains(err.Error(), "stale") { + t.Fatalf("stale write #0: unexpected error: %v", err) + } + if _, err := w1.Write([]byte("11")); err != nil { + t.Fatalf("fresh write #1: got %v want nil error", err) + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + if _, err := w1.Write([]byte("0")); err == nil || !strings.Contains(err.Error(), "stale") { + t.Fatalf("stale write #1: unexpected error: %v", err) + } +} + +func TestCorrupt_MissingLastBlock(t *testing.T) { + buf := new(bytes.Buffer) + + w := 
NewWriter(buf) + + // First record. + ww, err := w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), blockSize-1024)); err != nil { + t.Fatalf("write #0: unexpected error: %v", err) + } + + // Second record. + ww, err = w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), blockSize-headerSize)); err != nil { + t.Fatalf("write #1: unexpected error: %v", err) + } + + if err := w.Close(); err != nil { + t.Fatal(err) + } + + // Cut the last block. + b := buf.Bytes()[:blockSize] + r := NewReader(bytes.NewReader(b), dropper{t}, false, true) + + // First read. + rr, err := r.Next() + if err != nil { + t.Fatal(err) + } + n, err := io.Copy(ioutil.Discard, rr) + if err != nil { + t.Fatalf("read #0: %v", err) + } + if n != blockSize-1024 { + t.Fatalf("read #0: got %d bytes want %d", n, blockSize-1024) + } + + // Second read. + rr, err = r.Next() + if err != nil { + t.Fatal(err) + } + n, err = io.Copy(ioutil.Discard, rr) + if err != io.ErrUnexpectedEOF { + t.Fatalf("read #1: unexpected error: %v", err) + } + + if _, err := r.Next(); err != io.EOF { + t.Fatalf("last next: unexpected error: %v", err) + } +} + +func TestCorrupt_CorruptedFirstBlock(t *testing.T) { + buf := new(bytes.Buffer) + + w := NewWriter(buf) + + // First record. + ww, err := w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), blockSize/2)); err != nil { + t.Fatalf("write #0: unexpected error: %v", err) + } + + // Second record. + ww, err = w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), blockSize-headerSize)); err != nil { + t.Fatalf("write #1: unexpected error: %v", err) + } + + // Third record. + ww, err = w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), (blockSize-headerSize)+1)); err != nil { + t.Fatalf("write #2: unexpected error: %v", err) + } + + // Fourth record. + ww, err = w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), (blockSize-headerSize)+2)); err != nil { + t.Fatalf("write #3: unexpected error: %v", err) + } + + if err := w.Close(); err != nil { + t.Fatal(err) + } + + b := buf.Bytes() + // Corrupting block #0. + for i := 0; i < 1024; i++ { + b[i] = '1' + } + + r := NewReader(bytes.NewReader(b), dropper{t}, false, true) + + // First read (third record). + rr, err := r.Next() + if err != nil { + t.Fatal(err) + } + n, err := io.Copy(ioutil.Discard, rr) + if err != nil { + t.Fatalf("read #0: %v", err) + } + if want := int64(blockSize-headerSize) + 1; n != want { + t.Fatalf("read #0: got %d bytes want %d", n, want) + } + + // Second read (fourth record). + rr, err = r.Next() + if err != nil { + t.Fatal(err) + } + n, err = io.Copy(ioutil.Discard, rr) + if err != nil { + t.Fatalf("read #1: %v", err) + } + if want := int64(blockSize-headerSize) + 2; n != want { + t.Fatalf("read #1: got %d bytes want %d", n, want) + } + + if _, err := r.Next(); err != io.EOF { + t.Fatalf("last next: unexpected error: %v", err) + } +} + +func TestCorrupt_CorruptedMiddleBlock(t *testing.T) { + buf := new(bytes.Buffer) + + w := NewWriter(buf) + + // First record. + ww, err := w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), blockSize/2)); err != nil { + t.Fatalf("write #0: unexpected error: %v", err) + } + + // Second record. 
+ ww, err = w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), blockSize-headerSize)); err != nil { + t.Fatalf("write #1: unexpected error: %v", err) + } + + // Third record. + ww, err = w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), (blockSize-headerSize)+1)); err != nil { + t.Fatalf("write #2: unexpected error: %v", err) + } + + // Fourth record. + ww, err = w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), (blockSize-headerSize)+2)); err != nil { + t.Fatalf("write #3: unexpected error: %v", err) + } + + if err := w.Close(); err != nil { + t.Fatal(err) + } + + b := buf.Bytes() + // Corrupting block #1. + for i := 0; i < 1024; i++ { + b[blockSize+i] = '1' + } + + r := NewReader(bytes.NewReader(b), dropper{t}, false, true) + + // First read (first record). + rr, err := r.Next() + if err != nil { + t.Fatal(err) + } + n, err := io.Copy(ioutil.Discard, rr) + if err != nil { + t.Fatalf("read #0: %v", err) + } + if want := int64(blockSize / 2); n != want { + t.Fatalf("read #0: got %d bytes want %d", n, want) + } + + // Second read (second record). + rr, err = r.Next() + if err != nil { + t.Fatal(err) + } + n, err = io.Copy(ioutil.Discard, rr) + if err != io.ErrUnexpectedEOF { + t.Fatalf("read #1: unexpected error: %v", err) + } + + // Third read (fourth record). + rr, err = r.Next() + if err != nil { + t.Fatal(err) + } + n, err = io.Copy(ioutil.Discard, rr) + if err != nil { + t.Fatalf("read #2: %v", err) + } + if want := int64(blockSize-headerSize) + 2; n != want { + t.Fatalf("read #2: got %d bytes want %d", n, want) + } + + if _, err := r.Next(); err != io.EOF { + t.Fatalf("last next: unexpected error: %v", err) + } +} + +func TestCorrupt_CorruptedLastBlock(t *testing.T) { + buf := new(bytes.Buffer) + + w := NewWriter(buf) + + // First record. + ww, err := w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), blockSize/2)); err != nil { + t.Fatalf("write #0: unexpected error: %v", err) + } + + // Second record. + ww, err = w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), blockSize-headerSize)); err != nil { + t.Fatalf("write #1: unexpected error: %v", err) + } + + // Third record. + ww, err = w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), (blockSize-headerSize)+1)); err != nil { + t.Fatalf("write #2: unexpected error: %v", err) + } + + // Fourth record. + ww, err = w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), (blockSize-headerSize)+2)); err != nil { + t.Fatalf("write #3: unexpected error: %v", err) + } + + if err := w.Close(); err != nil { + t.Fatal(err) + } + + b := buf.Bytes() + // Corrupting block #3. + for i := len(b) - 1; i > len(b)-1024; i-- { + b[i] = '1' + } + + r := NewReader(bytes.NewReader(b), dropper{t}, false, true) + + // First read (first record). + rr, err := r.Next() + if err != nil { + t.Fatal(err) + } + n, err := io.Copy(ioutil.Discard, rr) + if err != nil { + t.Fatalf("read #0: %v", err) + } + if want := int64(blockSize / 2); n != want { + t.Fatalf("read #0: got %d bytes want %d", n, want) + } + + // Second read (second record). 
+ rr, err = r.Next() + if err != nil { + t.Fatal(err) + } + n, err = io.Copy(ioutil.Discard, rr) + if err != nil { + t.Fatalf("read #1: %v", err) + } + if want := int64(blockSize - headerSize); n != want { + t.Fatalf("read #1: got %d bytes want %d", n, want) + } + + // Third read (third record). + rr, err = r.Next() + if err != nil { + t.Fatal(err) + } + n, err = io.Copy(ioutil.Discard, rr) + if err != nil { + t.Fatalf("read #2: %v", err) + } + if want := int64(blockSize-headerSize) + 1; n != want { + t.Fatalf("read #2: got %d bytes want %d", n, want) + } + + // Fourth read (fourth record). + rr, err = r.Next() + if err != nil { + t.Fatal(err) + } + n, err = io.Copy(ioutil.Discard, rr) + if err != io.ErrUnexpectedEOF { + t.Fatalf("read #3: unexpected error: %v", err) + } + + if _, err := r.Next(); err != io.EOF { + t.Fatalf("last next: unexpected error: %v", err) + } +} + +func TestCorrupt_FirstChuckLengthOverflow(t *testing.T) { + buf := new(bytes.Buffer) + + w := NewWriter(buf) + + // First record. + ww, err := w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), blockSize/2)); err != nil { + t.Fatalf("write #0: unexpected error: %v", err) + } + + // Second record. + ww, err = w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), blockSize-headerSize)); err != nil { + t.Fatalf("write #1: unexpected error: %v", err) + } + + // Third record. + ww, err = w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), (blockSize-headerSize)+1)); err != nil { + t.Fatalf("write #2: unexpected error: %v", err) + } + + if err := w.Close(); err != nil { + t.Fatal(err) + } + + b := buf.Bytes() + // Corrupting record #1. + x := blockSize + binary.LittleEndian.PutUint16(b[x+4:], 0xffff) + + r := NewReader(bytes.NewReader(b), dropper{t}, false, true) + + // First read (first record). + rr, err := r.Next() + if err != nil { + t.Fatal(err) + } + n, err := io.Copy(ioutil.Discard, rr) + if err != nil { + t.Fatalf("read #0: %v", err) + } + if want := int64(blockSize / 2); n != want { + t.Fatalf("read #0: got %d bytes want %d", n, want) + } + + // Second read (second record). + rr, err = r.Next() + if err != nil { + t.Fatal(err) + } + n, err = io.Copy(ioutil.Discard, rr) + if err != io.ErrUnexpectedEOF { + t.Fatalf("read #1: unexpected error: %v", err) + } + + if _, err := r.Next(); err != io.EOF { + t.Fatalf("last next: unexpected error: %v", err) + } +} + +func TestCorrupt_MiddleChuckLengthOverflow(t *testing.T) { + buf := new(bytes.Buffer) + + w := NewWriter(buf) + + // First record. + ww, err := w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), blockSize/2)); err != nil { + t.Fatalf("write #0: unexpected error: %v", err) + } + + // Second record. + ww, err = w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), blockSize-headerSize)); err != nil { + t.Fatalf("write #1: unexpected error: %v", err) + } + + // Third record. + ww, err = w.Next() + if err != nil { + t.Fatal(err) + } + if _, err := ww.Write(bytes.Repeat([]byte("0"), (blockSize-headerSize)+1)); err != nil { + t.Fatalf("write #2: unexpected error: %v", err) + } + + if err := w.Close(); err != nil { + t.Fatal(err) + } + + b := buf.Bytes() + // Corrupting record #1. + x := blockSize/2 + headerSize + binary.LittleEndian.PutUint16(b[x+4:], 0xffff) + + r := NewReader(bytes.NewReader(b), dropper{t}, false, true) + + // First read (first record). 
+ rr, err := r.Next() + if err != nil { + t.Fatal(err) + } + n, err := io.Copy(ioutil.Discard, rr) + if err != nil { + t.Fatalf("read #0: %v", err) + } + if want := int64(blockSize / 2); n != want { + t.Fatalf("read #0: got %d bytes want %d", n, want) + } + + // Second read (third record). + rr, err = r.Next() + if err != nil { + t.Fatal(err) + } + n, err = io.Copy(ioutil.Discard, rr) + if err != nil { + t.Fatalf("read #1: %v", err) + } + if want := int64(blockSize-headerSize) + 1; n != want { + t.Fatalf("read #1: got %d bytes want %d", n, want) + } + + if _, err := r.Next(); err != io.EOF { + t.Fatalf("last next: unexpected error: %v", err) + } +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/key.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/key.go new file mode 100644 index 0000000000..572ae8150c --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/key.go @@ -0,0 +1,142 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package leveldb + +import ( + "encoding/binary" + "fmt" + + "github.com/syndtr/goleveldb/leveldb/errors" +) + +type ErrIkeyCorrupted struct { + Ikey []byte + Reason string +} + +func (e *ErrIkeyCorrupted) Error() string { + return fmt.Sprintf("leveldb: iKey %q corrupted: %s", e.Ikey, e.Reason) +} + +func newErrIkeyCorrupted(ikey []byte, reason string) error { + return errors.NewErrCorrupted(nil, &ErrIkeyCorrupted{append([]byte{}, ikey...), reason}) +} + +type kType int + +func (kt kType) String() string { + switch kt { + case ktDel: + return "d" + case ktVal: + return "v" + } + return "x" +} + +// Value types encoded as the last component of internal keys. +// Don't modify; this value are saved to disk. +const ( + ktDel kType = iota + ktVal +) + +// ktSeek defines the kType that should be passed when constructing an +// internal key for seeking to a particular sequence number (since we +// sort sequence numbers in decreasing order and the value type is +// embedded as the low 8 bits in the sequence number in internal keys, +// we need to use the highest-numbered ValueType, not the lowest). +const ktSeek = ktVal + +const ( + // Maximum value possible for sequence number; the 8-bits are + // used by value type, so its can packed together in single + // 64-bit integer. + kMaxSeq uint64 = (uint64(1) << 56) - 1 + // Maximum value possible for packed sequence number and type. + kMaxNum uint64 = (kMaxSeq << 8) | uint64(ktSeek) +) + +// Maximum number encoded in bytes. 
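+// Appending kMaxNumBytes to a user key therefore builds an internal key
+// that sorts before every existing entry for that user key, since packed
+// sequence numbers sort in decreasing order; this is what seeking wants
+// (cf. the ktSeek note above; a clarifying note, the callers live outside
+// this file).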
+var kMaxNumBytes = make([]byte, 8) + +func init() { + binary.LittleEndian.PutUint64(kMaxNumBytes, kMaxNum) +} + +type iKey []byte + +func newIkey(ukey []byte, seq uint64, kt kType) iKey { + if seq > kMaxSeq { + panic("leveldb: invalid sequence number") + } else if kt > ktVal { + panic("leveldb: invalid type") + } + + ik := make(iKey, len(ukey)+8) + copy(ik, ukey) + binary.LittleEndian.PutUint64(ik[len(ukey):], (seq<<8)|uint64(kt)) + return ik +} + +func parseIkey(ik []byte) (ukey []byte, seq uint64, kt kType, err error) { + if len(ik) < 8 { + return nil, 0, 0, newErrIkeyCorrupted(ik, "invalid length") + } + num := binary.LittleEndian.Uint64(ik[len(ik)-8:]) + seq, kt = uint64(num>>8), kType(num&0xff) + if kt > ktVal { + return nil, 0, 0, newErrIkeyCorrupted(ik, "invalid type") + } + ukey = ik[:len(ik)-8] + return +} + +func validIkey(ik []byte) bool { + _, _, _, err := parseIkey(ik) + return err == nil +} + +func (ik iKey) assert() { + if ik == nil { + panic("leveldb: nil iKey") + } + if len(ik) < 8 { + panic(fmt.Sprintf("leveldb: iKey %q, len=%d: invalid length", []byte(ik), len(ik))) + } +} + +func (ik iKey) ukey() []byte { + ik.assert() + return ik[:len(ik)-8] +} + +func (ik iKey) num() uint64 { + ik.assert() + return binary.LittleEndian.Uint64(ik[len(ik)-8:]) +} + +func (ik iKey) parseNum() (seq uint64, kt kType) { + num := ik.num() + seq, kt = uint64(num>>8), kType(num&0xff) + if kt > ktVal { + panic(fmt.Sprintf("leveldb: iKey %q, len=%d: invalid type %#x", []byte(ik), len(ik), kt)) + } + return +} + +func (ik iKey) String() string { + if ik == nil { + return "" + } + + if ukey, seq, kt, err := parseIkey(ik); err == nil { + return fmt.Sprintf("%s,%s%d", shorten(string(ukey)), kt, seq) + } else { + return "" + } +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/key_test.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/key_test.go new file mode 100644 index 0000000000..30eadf7847 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/key_test.go @@ -0,0 +1,133 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
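+
+// The tests below exercise the internal key ("iKey") layout defined in
+// key.go: a user key followed by 8 little-endian bytes packing the
+// sequence number and kType. A sketch, using names from this package:
+//
+//	ik := newIkey([]byte("foo"), 100, ktVal)
+//	// ik == "foo" + uint64le(100<<8 | 1)
+//	ukey, seq, kt, err := parseIkey(ik) // "foo", 100, ktVal, nil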
+ +package leveldb + +import ( + "bytes" + "testing" + + "github.com/syndtr/goleveldb/leveldb/comparer" +) + +var defaultIComparer = &iComparer{comparer.DefaultComparer} + +func ikey(key string, seq uint64, kt kType) iKey { + return newIkey([]byte(key), uint64(seq), kt) +} + +func shortSep(a, b []byte) []byte { + dst := make([]byte, len(a)) + dst = defaultIComparer.Separator(dst[:0], a, b) + if dst == nil { + return a + } + return dst +} + +func shortSuccessor(b []byte) []byte { + dst := make([]byte, len(b)) + dst = defaultIComparer.Successor(dst[:0], b) + if dst == nil { + return b + } + return dst +} + +func testSingleKey(t *testing.T, key string, seq uint64, kt kType) { + ik := ikey(key, seq, kt) + + if !bytes.Equal(ik.ukey(), []byte(key)) { + t.Errorf("user key does not equal, got %v, want %v", string(ik.ukey()), key) + } + + rseq, rt := ik.parseNum() + if rseq != seq { + t.Errorf("seq number does not equal, got %v, want %v", rseq, seq) + } + if rt != kt { + t.Errorf("type does not equal, got %v, want %v", rt, kt) + } + + if rukey, rseq, rt, kerr := parseIkey(ik); kerr == nil { + if !bytes.Equal(rukey, []byte(key)) { + t.Errorf("user key does not equal, got %v, want %v", string(ik.ukey()), key) + } + if rseq != seq { + t.Errorf("seq number does not equal, got %v, want %v", rseq, seq) + } + if rt != kt { + t.Errorf("type does not equal, got %v, want %v", rt, kt) + } + } else { + t.Errorf("key error: %v", kerr) + } +} + +func TestIkey_EncodeDecode(t *testing.T) { + keys := []string{"", "k", "hello", "longggggggggggggggggggggg"} + seqs := []uint64{ + 1, 2, 3, + (1 << 8) - 1, 1 << 8, (1 << 8) + 1, + (1 << 16) - 1, 1 << 16, (1 << 16) + 1, + (1 << 32) - 1, 1 << 32, (1 << 32) + 1, + } + for _, key := range keys { + for _, seq := range seqs { + testSingleKey(t, key, seq, ktVal) + testSingleKey(t, "hello", 1, ktDel) + } + } +} + +func assertBytes(t *testing.T, want, got []byte) { + if !bytes.Equal(got, want) { + t.Errorf("assert failed, got %v, want %v", got, want) + } +} + +func TestIkeyShortSeparator(t *testing.T) { + // When user keys are same + assertBytes(t, ikey("foo", 100, ktVal), + shortSep(ikey("foo", 100, ktVal), + ikey("foo", 99, ktVal))) + assertBytes(t, ikey("foo", 100, ktVal), + shortSep(ikey("foo", 100, ktVal), + ikey("foo", 101, ktVal))) + assertBytes(t, ikey("foo", 100, ktVal), + shortSep(ikey("foo", 100, ktVal), + ikey("foo", 100, ktVal))) + assertBytes(t, ikey("foo", 100, ktVal), + shortSep(ikey("foo", 100, ktVal), + ikey("foo", 100, ktDel))) + + // When user keys are misordered + assertBytes(t, ikey("foo", 100, ktVal), + shortSep(ikey("foo", 100, ktVal), + ikey("bar", 99, ktVal))) + + // When user keys are different, but correctly ordered + assertBytes(t, ikey("g", uint64(kMaxSeq), ktSeek), + shortSep(ikey("foo", 100, ktVal), + ikey("hello", 200, ktVal))) + + // When start user key is prefix of limit user key + assertBytes(t, ikey("foo", 100, ktVal), + shortSep(ikey("foo", 100, ktVal), + ikey("foobar", 200, ktVal))) + + // When limit user key is prefix of start user key + assertBytes(t, ikey("foobar", 100, ktVal), + shortSep(ikey("foobar", 100, ktVal), + ikey("foo", 200, ktVal))) +} + +func TestIkeyShortestSuccessor(t *testing.T) { + assertBytes(t, ikey("g", uint64(kMaxSeq), ktSeek), + shortSuccessor(ikey("foo", 100, ktVal))) + assertBytes(t, ikey("\xff\xff", 100, ktVal), + shortSuccessor(ikey("\xff\xff", 100, ktVal))) +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/leveldb_suite_test.go 
b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/leveldb_suite_test.go new file mode 100644 index 0000000000..fefa007a70 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/leveldb_suite_test.go @@ -0,0 +1,11 @@ +package leveldb + +import ( + "testing" + + "github.com/syndtr/goleveldb/leveldb/testutil" +) + +func TestLevelDB(t *testing.T) { + testutil.RunSuite(t, "LevelDB Suite") +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/memdb/bench_test.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/memdb/bench_test.go new file mode 100644 index 0000000000..b05084caa6 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/memdb/bench_test.go @@ -0,0 +1,75 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package memdb + +import ( + "encoding/binary" + "math/rand" + "testing" + + "github.com/syndtr/goleveldb/leveldb/comparer" +) + +func BenchmarkPut(b *testing.B) { + buf := make([][4]byte, b.N) + for i := range buf { + binary.LittleEndian.PutUint32(buf[i][:], uint32(i)) + } + + b.ResetTimer() + p := New(comparer.DefaultComparer, 0) + for i := range buf { + p.Put(buf[i][:], nil) + } +} + +func BenchmarkPutRandom(b *testing.B) { + buf := make([][4]byte, b.N) + for i := range buf { + binary.LittleEndian.PutUint32(buf[i][:], uint32(rand.Int())) + } + + b.ResetTimer() + p := New(comparer.DefaultComparer, 0) + for i := range buf { + p.Put(buf[i][:], nil) + } +} + +func BenchmarkGet(b *testing.B) { + buf := make([][4]byte, b.N) + for i := range buf { + binary.LittleEndian.PutUint32(buf[i][:], uint32(i)) + } + + p := New(comparer.DefaultComparer, 0) + for i := range buf { + p.Put(buf[i][:], nil) + } + + b.ResetTimer() + for i := range buf { + p.Get(buf[i][:]) + } +} + +func BenchmarkGetRandom(b *testing.B) { + buf := make([][4]byte, b.N) + for i := range buf { + binary.LittleEndian.PutUint32(buf[i][:], uint32(i)) + } + + p := New(comparer.DefaultComparer, 0) + for i := range buf { + p.Put(buf[i][:], nil) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + p.Get(buf[rand.Int()%b.N][:]) + } +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/memdb/memdb.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/memdb/memdb.go new file mode 100644 index 0000000000..e5398873b7 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/memdb/memdb.go @@ -0,0 +1,468 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Package memdb provides in-memory key/value database implementation. 
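+//
+// A minimal usage sketch (names as defined in this package):
+//
+//	db := New(comparer.DefaultComparer, 0)
+//	db.Put([]byte("k"), []byte("v"))
+//	value, err := db.Get([]byte("k")) // []byte("v"), nil
+//	db.Delete([]byte("k"))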
+package memdb + +import ( + "math/rand" + "sync" + + "github.com/syndtr/goleveldb/leveldb/comparer" + "github.com/syndtr/goleveldb/leveldb/errors" + "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/util" +) + +var ( + ErrNotFound = errors.ErrNotFound + ErrIterReleased = errors.New("leveldb/memdb: iterator released") +) + +const tMaxHeight = 12 + +type dbIter struct { + util.BasicReleaser + p *DB + slice *util.Range + node int + forward bool + key, value []byte + err error +} + +func (i *dbIter) fill(checkStart, checkLimit bool) bool { + if i.node != 0 { + n := i.p.nodeData[i.node] + m := n + i.p.nodeData[i.node+nKey] + i.key = i.p.kvData[n:m] + if i.slice != nil { + switch { + case checkLimit && i.slice.Limit != nil && i.p.cmp.Compare(i.key, i.slice.Limit) >= 0: + fallthrough + case checkStart && i.slice.Start != nil && i.p.cmp.Compare(i.key, i.slice.Start) < 0: + i.node = 0 + goto bail + } + } + i.value = i.p.kvData[m : m+i.p.nodeData[i.node+nVal]] + return true + } +bail: + i.key = nil + i.value = nil + return false +} + +func (i *dbIter) Valid() bool { + return i.node != 0 +} + +func (i *dbIter) First() bool { + if i.Released() { + i.err = ErrIterReleased + return false + } + + i.forward = true + i.p.mu.RLock() + defer i.p.mu.RUnlock() + if i.slice != nil && i.slice.Start != nil { + i.node, _ = i.p.findGE(i.slice.Start, false) + } else { + i.node = i.p.nodeData[nNext] + } + return i.fill(false, true) +} + +func (i *dbIter) Last() bool { + if i.Released() { + i.err = ErrIterReleased + return false + } + + i.forward = false + i.p.mu.RLock() + defer i.p.mu.RUnlock() + if i.slice != nil && i.slice.Limit != nil { + i.node = i.p.findLT(i.slice.Limit) + } else { + i.node = i.p.findLast() + } + return i.fill(true, false) +} + +func (i *dbIter) Seek(key []byte) bool { + if i.Released() { + i.err = ErrIterReleased + return false + } + + i.forward = true + i.p.mu.RLock() + defer i.p.mu.RUnlock() + if i.slice != nil && i.slice.Start != nil && i.p.cmp.Compare(key, i.slice.Start) < 0 { + key = i.slice.Start + } + i.node, _ = i.p.findGE(key, false) + return i.fill(false, true) +} + +func (i *dbIter) Next() bool { + if i.Released() { + i.err = ErrIterReleased + return false + } + + if i.node == 0 { + if !i.forward { + return i.First() + } + return false + } + i.forward = true + i.p.mu.RLock() + defer i.p.mu.RUnlock() + i.node = i.p.nodeData[i.node+nNext] + return i.fill(false, true) +} + +func (i *dbIter) Prev() bool { + if i.Released() { + i.err = ErrIterReleased + return false + } + + if i.node == 0 { + if i.forward { + return i.Last() + } + return false + } + i.forward = false + i.p.mu.RLock() + defer i.p.mu.RUnlock() + i.node = i.p.findLT(i.key) + return i.fill(true, false) +} + +func (i *dbIter) Key() []byte { + return i.key +} + +func (i *dbIter) Value() []byte { + return i.value +} + +func (i *dbIter) Error() error { return i.err } + +func (i *dbIter) Release() { + if !i.Released() { + i.p = nil + i.node = 0 + i.key = nil + i.value = nil + i.BasicReleaser.Release() + } +} + +const ( + nKV = iota + nKey + nVal + nHeight + nNext +) + +// DB is an in-memory key/value database. 
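+//
+// Internally DB is a skiplist stored in two flat slices (an explanatory
+// note derived from the code below): kvData holds keys and values back to
+// back, while nodeData holds one record per node. For a node starting at
+// offset n with height h, the next pointers live at nodeData[n+4] through
+// nodeData[n+4+h-1], matching the "m := n + 4 + i" arithmetic in Put and
+// Delete.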
+type DB struct { + cmp comparer.BasicComparer + rnd *rand.Rand + + mu sync.RWMutex + kvData []byte + // Node data: + // [0] : KV offset + // [1] : Key length + // [2] : Value length + // [3] : Height + // [3..height] : Next nodes + nodeData []int + prevNode [tMaxHeight]int + maxHeight int + n int + kvSize int +} + +func (p *DB) randHeight() (h int) { + const branching = 4 + h = 1 + for h < tMaxHeight && p.rnd.Int()%branching == 0 { + h++ + } + return +} + +func (p *DB) findGE(key []byte, prev bool) (int, bool) { + node := 0 + h := p.maxHeight - 1 + for { + next := p.nodeData[node+nNext+h] + cmp := 1 + if next != 0 { + o := p.nodeData[next] + cmp = p.cmp.Compare(p.kvData[o:o+p.nodeData[next+nKey]], key) + } + if cmp < 0 { + // Keep searching in this list + node = next + } else { + if prev { + p.prevNode[h] = node + } else if cmp == 0 { + return next, true + } + if h == 0 { + return next, cmp == 0 + } + h-- + } + } +} + +func (p *DB) findLT(key []byte) int { + node := 0 + h := p.maxHeight - 1 + for { + next := p.nodeData[node+nNext+h] + o := p.nodeData[next] + if next == 0 || p.cmp.Compare(p.kvData[o:o+p.nodeData[next+nKey]], key) >= 0 { + if h == 0 { + break + } + h-- + } else { + node = next + } + } + return node +} + +func (p *DB) findLast() int { + node := 0 + h := p.maxHeight - 1 + for { + next := p.nodeData[node+nNext+h] + if next == 0 { + if h == 0 { + break + } + h-- + } else { + node = next + } + } + return node +} + +// Put sets the value for the given key. It overwrites any previous value +// for that key; a DB is not a multi-map. +// +// It is safe to modify the contents of the arguments after Put returns. +func (p *DB) Put(key []byte, value []byte) error { + p.mu.Lock() + defer p.mu.Unlock() + + if node, exact := p.findGE(key, true); exact { + kvOffset := len(p.kvData) + p.kvData = append(p.kvData, key...) + p.kvData = append(p.kvData, value...) + p.nodeData[node] = kvOffset + m := p.nodeData[node+nVal] + p.nodeData[node+nVal] = len(value) + p.kvSize += len(value) - m + return nil + } + + h := p.randHeight() + if h > p.maxHeight { + for i := p.maxHeight; i < h; i++ { + p.prevNode[i] = 0 + } + p.maxHeight = h + } + + kvOffset := len(p.kvData) + p.kvData = append(p.kvData, key...) + p.kvData = append(p.kvData, value...) + // Node + node := len(p.nodeData) + p.nodeData = append(p.nodeData, kvOffset, len(key), len(value), h) + for i, n := range p.prevNode[:h] { + m := n + 4 + i + p.nodeData = append(p.nodeData, p.nodeData[m]) + p.nodeData[m] = node + } + + p.kvSize += len(key) + len(value) + p.n++ + return nil +} + +// Delete deletes the value for the given key. It returns ErrNotFound if +// the DB does not contain the key. +// +// It is safe to modify the contents of the arguments after Delete returns. +func (p *DB) Delete(key []byte) error { + p.mu.Lock() + defer p.mu.Unlock() + + node, exact := p.findGE(key, true) + if !exact { + return ErrNotFound + } + + h := p.nodeData[node+nHeight] + for i, n := range p.prevNode[:h] { + m := n + 4 + i + p.nodeData[m] = p.nodeData[p.nodeData[m]+nNext+i] + } + + p.kvSize -= p.nodeData[node+nKey] + p.nodeData[node+nVal] + p.n-- + return nil +} + +// Contains returns true if the given key are in the DB. +// +// It is safe to modify the contents of the arguments after Contains returns. +func (p *DB) Contains(key []byte) bool { + p.mu.RLock() + _, exact := p.findGE(key, false) + p.mu.RUnlock() + return exact +} + +// Get gets the value for the given key. It returns error.ErrNotFound if the +// DB does not contain the key. 
+//
+// The caller should not modify the contents of the returned slice, but
+// it is safe to modify the contents of the argument after Get returns.
+func (p *DB) Get(key []byte) (value []byte, err error) {
+	p.mu.RLock()
+	if node, exact := p.findGE(key, false); exact {
+		o := p.nodeData[node] + p.nodeData[node+nKey]
+		value = p.kvData[o : o+p.nodeData[node+nVal]]
+	} else {
+		err = ErrNotFound
+	}
+	p.mu.RUnlock()
+	return
+}
+
+// Find finds the key/value pair whose key is greater than or equal to the
+// given key. It returns ErrNotFound if the table doesn't contain
+// such a pair.
+//
+// The caller should not modify the contents of the returned slices, but
+// it is safe to modify the contents of the argument after Find returns.
+func (p *DB) Find(key []byte) (rkey, value []byte, err error) {
+	p.mu.RLock()
+	if node, _ := p.findGE(key, false); node != 0 {
+		n := p.nodeData[node]
+		m := n + p.nodeData[node+nKey]
+		rkey = p.kvData[n:m]
+		value = p.kvData[m : m+p.nodeData[node+nVal]]
+	} else {
+		err = ErrNotFound
+	}
+	p.mu.RUnlock()
+	return
+}
+
+// NewIterator returns an iterator of the DB.
+// The returned iterator is not goroutine-safe, but it is safe to use
+// multiple iterators concurrently, with each in a dedicated goroutine.
+// It is also safe to use an iterator concurrently with modifying its
+// underlying DB. However, the resultant key/value pairs are not guaranteed
+// to be a consistent snapshot of the DB at a particular point in time.
+//
+// Slice allows slicing the iterator to contain only keys in the given
+// range. A nil Range.Start is treated as a key before all keys in the
+// DB. And a nil Range.Limit is treated as a key after all keys in
+// the DB.
+//
+// The iterator must be released after use, by calling the Release method.
+//
+// Also read the Iterator documentation of the leveldb/iterator package.
+func (p *DB) NewIterator(slice *util.Range) iterator.Iterator {
+	return &dbIter{p: p, slice: slice}
+}
+
+// Capacity returns the keys/values buffer capacity.
+func (p *DB) Capacity() int {
+	p.mu.RLock()
+	defer p.mu.RUnlock()
+	return cap(p.kvData)
+}
+
+// Size returns the sum of the key and value lengths. Note that deleted
+// key/value pairs will not be accounted for, but they will still consume
+// the buffer, since the buffer is append-only.
+func (p *DB) Size() int {
+	p.mu.RLock()
+	defer p.mu.RUnlock()
+	return p.kvSize
+}
+
+// Free returns the free space remaining in the keys/values buffer before
+// it needs to grow.
+func (p *DB) Free() int {
+	p.mu.RLock()
+	defer p.mu.RUnlock()
+	return cap(p.kvData) - len(p.kvData)
+}
+
+// Len returns the number of entries in the DB.
+func (p *DB) Len() int {
+	p.mu.RLock()
+	defer p.mu.RUnlock()
+	return p.n
+}
+
+// Reset resets the DB to its initial empty state. This allows the buffer
+// to be reused.
+func (p *DB) Reset() {
+	p.rnd = rand.New(rand.NewSource(0xdeadbeef))
+	p.maxHeight = 1
+	p.n = 0
+	p.kvSize = 0
+	p.kvData = p.kvData[:0]
+	p.nodeData = p.nodeData[:4+tMaxHeight]
+	p.nodeData[nKV] = 0
+	p.nodeData[nKey] = 0
+	p.nodeData[nVal] = 0
+	p.nodeData[nHeight] = tMaxHeight
+	for n := 0; n < tMaxHeight; n++ {
+		p.nodeData[4+n] = 0
+		p.prevNode[n] = 0
+	}
+}
+
+// New creates a new initialized in-memory key/value DB. The capacity
+// is the initial key/value buffer capacity. The capacity is advisory,
+// not enforced.
+//
+// The returned DB instance is goroutine-safe.
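+//
+// A reuse sketch (illustrative only; fill stands for hypothetical caller
+// code): Reset empties the DB while keeping the allocated buffers, so a
+// write-heavy loop can recycle one instance instead of reallocating:
+//
+//	db := New(comparer.DefaultComparer, 1<<20) // ~1MiB initial buffer
+//	fill(db)   // hypothetical: Put many entries
+//	db.Reset() // empty the DB, keep kvData's capacity
+//	fill(db)   // reuse the same buffers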
+func New(cmp comparer.BasicComparer, capacity int) *DB { + p := &DB{ + cmp: cmp, + rnd: rand.New(rand.NewSource(0xdeadbeef)), + maxHeight: 1, + kvData: make([]byte, 0, capacity), + nodeData: make([]int, 4+tMaxHeight), + } + p.nodeData[nHeight] = tMaxHeight + return p +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/memdb/memdb_suite_test.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/memdb/memdb_suite_test.go new file mode 100644 index 0000000000..18c304b7f1 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/memdb/memdb_suite_test.go @@ -0,0 +1,11 @@ +package memdb + +import ( + "testing" + + "github.com/syndtr/goleveldb/leveldb/testutil" +) + +func TestMemDB(t *testing.T) { + testutil.RunSuite(t, "MemDB Suite") +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/memdb/memdb_test.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/memdb/memdb_test.go new file mode 100644 index 0000000000..5dd6dbc7b7 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/memdb/memdb_test.go @@ -0,0 +1,135 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package memdb + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "github.com/syndtr/goleveldb/leveldb/comparer" + "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/testutil" + "github.com/syndtr/goleveldb/leveldb/util" +) + +func (p *DB) TestFindLT(key []byte) (rkey, value []byte, err error) { + p.mu.RLock() + if node := p.findLT(key); node != 0 { + n := p.nodeData[node] + m := n + p.nodeData[node+nKey] + rkey = p.kvData[n:m] + value = p.kvData[m : m+p.nodeData[node+nVal]] + } else { + err = ErrNotFound + } + p.mu.RUnlock() + return +} + +func (p *DB) TestFindLast() (rkey, value []byte, err error) { + p.mu.RLock() + if node := p.findLast(); node != 0 { + n := p.nodeData[node] + m := n + p.nodeData[node+nKey] + rkey = p.kvData[n:m] + value = p.kvData[m : m+p.nodeData[node+nVal]] + } else { + err = ErrNotFound + } + p.mu.RUnlock() + return +} + +func (p *DB) TestPut(key []byte, value []byte) error { + p.Put(key, value) + return nil +} + +func (p *DB) TestDelete(key []byte) error { + p.Delete(key) + return nil +} + +func (p *DB) TestFind(key []byte) (rkey, rvalue []byte, err error) { + return p.Find(key) +} + +func (p *DB) TestGet(key []byte) (value []byte, err error) { + return p.Get(key) +} + +func (p *DB) TestNewIterator(slice *util.Range) iterator.Iterator { + return p.NewIterator(slice) +} + +var _ = testutil.Defer(func() { + Describe("Memdb", func() { + Describe("write test", func() { + It("should do write correctly", func() { + db := New(comparer.DefaultComparer, 0) + t := testutil.DBTesting{ + DB: db, + Deleted: testutil.KeyValue_Generate(nil, 1000, 1, 30, 5, 5).Clone(), + PostFn: func(t *testutil.DBTesting) { + Expect(db.Len()).Should(Equal(t.Present.Len())) + Expect(db.Size()).Should(Equal(t.Present.Size())) + switch t.Act { + case testutil.DBPut, testutil.DBOverwrite: + Expect(db.Contains(t.ActKey)).Should(BeTrue()) + default: + Expect(db.Contains(t.ActKey)).Should(BeFalse()) + } + }, + } + testutil.DoDBTesting(&t) + }) + }) + + Describe("read test", func() { + testutil.AllKeyValueTesting(nil, func(kv testutil.KeyValue) testutil.DB { + // Building the DB. 
+ db := New(comparer.DefaultComparer, 0) + kv.IterateShuffled(nil, func(i int, key, value []byte) { + db.Put(key, value) + }) + + if kv.Len() > 1 { + It("Should find correct keys with findLT", func() { + testutil.ShuffledIndex(nil, kv.Len()-1, 1, func(i int) { + key_, key, _ := kv.IndexInexact(i + 1) + expectedKey, expectedValue := kv.Index(i) + + // Using key that exist. + rkey, rvalue, err := db.TestFindLT(key) + Expect(err).ShouldNot(HaveOccurred(), "Error for key %q -> %q", key, expectedKey) + Expect(rkey).Should(Equal(expectedKey), "Key") + Expect(rvalue).Should(Equal(expectedValue), "Value for key %q -> %q", key, expectedKey) + + // Using key that doesn't exist. + rkey, rvalue, err = db.TestFindLT(key_) + Expect(err).ShouldNot(HaveOccurred(), "Error for key %q (%q) -> %q", key_, key, expectedKey) + Expect(rkey).Should(Equal(expectedKey)) + Expect(rvalue).Should(Equal(expectedValue), "Value for key %q (%q) -> %q", key_, key, expectedKey) + }) + }) + } + + if kv.Len() > 0 { + It("Should find last key with findLast", func() { + key, value := kv.Index(kv.Len() - 1) + rkey, rvalue, err := db.TestFindLast() + Expect(err).ShouldNot(HaveOccurred()) + Expect(rkey).Should(Equal(key)) + Expect(rvalue).Should(Equal(value)) + }) + } + + return db + }, nil, nil) + }) + }) +}) diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/opt/options.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/opt/options.go new file mode 100644 index 0000000000..86828f47b3 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/opt/options.go @@ -0,0 +1,624 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Package opt provides sets of options used by LevelDB. +package opt + +import ( + "github.com/syndtr/goleveldb/leveldb/cache" + "github.com/syndtr/goleveldb/leveldb/comparer" + "github.com/syndtr/goleveldb/leveldb/filter" + "math" +) + +const ( + KiB = 1024 + MiB = KiB * 1024 + GiB = MiB * 1024 +) + +var ( + DefaultBlockCacher = LRUCacher + DefaultBlockCacheCapacity = 8 * MiB + DefaultBlockRestartInterval = 16 + DefaultBlockSize = 4 * KiB + DefaultCompactionExpandLimitFactor = 25 + DefaultCompactionGPOverlapsFactor = 10 + DefaultCompactionL0Trigger = 4 + DefaultCompactionSourceLimitFactor = 1 + DefaultCompactionTableSize = 2 * MiB + DefaultCompactionTableSizeMultiplier = 1.0 + DefaultCompactionTotalSize = 10 * MiB + DefaultCompactionTotalSizeMultiplier = 10.0 + DefaultCompressionType = SnappyCompression + DefaultOpenFilesCacher = LRUCacher + DefaultOpenFilesCacheCapacity = 500 + DefaultMaxMemCompationLevel = 2 + DefaultNumLevel = 7 + DefaultWriteBuffer = 4 * MiB + DefaultWriteL0PauseTrigger = 12 + DefaultWriteL0SlowdownTrigger = 8 +) + +// Cacher is a caching algorithm. +type Cacher interface { + New(capacity int) cache.Cacher +} + +type CacherFunc struct { + NewFunc func(capacity int) cache.Cacher +} + +func (f *CacherFunc) New(capacity int) cache.Cacher { + if f.NewFunc != nil { + return f.NewFunc(capacity) + } + return nil +} + +func noCacher(int) cache.Cacher { return nil } + +var ( + // LRUCacher is the LRU-cache algorithm. + LRUCacher = &CacherFunc{cache.NewLRU} + + // NoCacher is the value to disable caching algorithm. + NoCacher = &CacherFunc{} +) + +// Compression is the 'sorted table' block compression algorithm to use. 
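+//
+// A sketch of selecting it through Options (names from this package):
+//
+//	o := &opt.Options{
+//		Compression: opt.NoCompression, // DefaultCompression resolves to snappy
+//	}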
+type Compression uint
+
+func (c Compression) String() string {
+	switch c {
+	case DefaultCompression:
+		return "default"
+	case NoCompression:
+		return "none"
+	case SnappyCompression:
+		return "snappy"
+	}
+	return "invalid"
+}
+
+const (
+	DefaultCompression Compression = iota
+	NoCompression
+	SnappyCompression
+	nCompression
+)
+
+// Strict is the DB 'strict level'.
+type Strict uint
+
+const (
+	// If present then a corrupted or invalid chunk or block in the manifest
+	// journal will cause an error instead of being dropped.
+	// This will prevent a database with a corrupted manifest from being opened.
+	StrictManifest Strict = 1 << iota
+
+	// If present then journal chunk checksums will be verified.
+	StrictJournalChecksum
+
+	// If present then a corrupted or invalid chunk or block in the journal
+	// will cause an error instead of being dropped.
+	// This will prevent a database with a corrupted journal from being opened.
+	StrictJournal
+
+	// If present then 'sorted table' block checksums will be verified.
+	// This has an effect on both 'read operation' and compaction.
+	StrictBlockChecksum
+
+	// If present then a corrupted 'sorted table' will fail compaction.
+	// The database will enter read-only mode.
+	StrictCompaction
+
+	// If present then a corrupted 'sorted table' will halt 'read operation'.
+	StrictReader
+
+	// If present then leveldb.Recover will drop corrupted 'sorted table'.
+	StrictRecovery
+
+	// This is only applicable to ReadOptions; if present then this
+	// ReadOptions 'strict level' will override the global one.
+	StrictOverride
+
+	// StrictAll enables all strict flags.
+	StrictAll = StrictManifest | StrictJournalChecksum | StrictJournal | StrictBlockChecksum | StrictCompaction | StrictReader | StrictRecovery
+
+	// DefaultStrict is the default strict flags. Specifying any strict flags
+	// will override the default strict flags as a whole (i.e. not OR'ed).
+	DefaultStrict = StrictJournalChecksum | StrictBlockChecksum | StrictCompaction | StrictReader
+
+	// NoStrict disables all strict flags. Overrides the default strict flags.
+	NoStrict = ^StrictAll
+)
+
+// Options holds the optional parameters for the DB at large.
+type Options struct {
+	// AltFilters defines one or more 'alternative filters'.
+	// 'alternative filters' will be used during reads if a filter block
+	// does not match the 'effective filter'.
+	//
+	// The default value is nil.
+	AltFilters []filter.Filter
+
+	// BlockCacher provides the cache algorithm for LevelDB 'sorted table'
+	// block caching.
+	// Specify NoCacher to disable the caching algorithm.
+	//
+	// The default value is LRUCacher.
+	BlockCacher Cacher
+
+	// BlockCacheCapacity defines the capacity of the 'sorted table' block caching.
+	// Use -1 for zero; this has the same effect as specifying NoCacher to BlockCacher.
+	//
+	// The default value is 8MiB.
+	BlockCacheCapacity int
+
+	// BlockRestartInterval is the number of keys between restart points for
+	// delta encoding of keys.
+	//
+	// The default value is 16.
+	BlockRestartInterval int
+
+	// BlockSize is the minimum uncompressed size in bytes of each 'sorted table'
+	// block.
+	//
+	// The default value is 4KiB.
+	BlockSize int
+
+	// CompactionExpandLimitFactor limits the compaction size after expansion.
+	// This will be multiplied by the table size limit at the compaction target level.
+	//
+	// The default value is 25.
+	CompactionExpandLimitFactor int
+
+	// CompactionGPOverlapsFactor limits overlaps in the grandparent (Level + 2)
+	// that a single 'sorted table' generates.
+	// This will be multiplied by the table size limit at the grandparent level.
+	//
+	// The default value is 10.
+	CompactionGPOverlapsFactor int
+
+	// CompactionL0Trigger defines the number of 'sorted table' at level-0 that
+	// will trigger compaction.
+	//
+	// The default value is 4.
+	CompactionL0Trigger int
+
+	// CompactionSourceLimitFactor limits the compaction source size. This
+	// doesn't apply to level-0.
+	// This will be multiplied by the table size limit at the compaction target level.
+	//
+	// The default value is 1.
+	CompactionSourceLimitFactor int
+
+	// CompactionTableSize limits the size of 'sorted table' that compaction
+	// generates. The limits for each level will be calculated as:
+	//	CompactionTableSize * (CompactionTableSizeMultiplier ^ Level)
+	// The multiplier for each level can also be fine-tuned using
+	// CompactionTableSizeMultiplierPerLevel.
+	//
+	// The default value is 2MiB.
+	CompactionTableSize int
+
+	// CompactionTableSizeMultiplier defines the multiplier for CompactionTableSize.
+	//
+	// The default value is 1.
+	CompactionTableSizeMultiplier float64
+
+	// CompactionTableSizeMultiplierPerLevel defines a per-level multiplier for
+	// CompactionTableSize.
+	// Use zero to skip a level.
+	//
+	// The default value is nil.
+	CompactionTableSizeMultiplierPerLevel []float64
+
+	// CompactionTotalSize limits the total size of 'sorted table' for each level.
+	// The limits for each level will be calculated as:
+	//	CompactionTotalSize * (CompactionTotalSizeMultiplier ^ Level)
+	// The multiplier for each level can also be fine-tuned using
+	// CompactionTotalSizeMultiplierPerLevel.
+	//
+	// The default value is 10MiB.
+	CompactionTotalSize int
+
+	// CompactionTotalSizeMultiplier defines the multiplier for CompactionTotalSize.
+	//
+	// The default value is 10.
+	CompactionTotalSizeMultiplier float64
+
+	// CompactionTotalSizeMultiplierPerLevel defines a per-level multiplier for
+	// CompactionTotalSize.
+	// Use zero to skip a level.
+	//
+	// The default value is nil.
+	CompactionTotalSizeMultiplierPerLevel []float64
+
+	// Comparer defines a total ordering over the space of []byte keys: a 'less
+	// than' relationship. The same comparison algorithm must be used for reads
+	// and writes over the lifetime of the DB.
+	//
+	// The default value uses the same ordering as bytes.Compare.
+	Comparer comparer.Comparer
+
+	// Compression defines the 'sorted table' block compression to use.
+	//
+	// The default value (DefaultCompression) uses snappy compression.
+	Compression Compression
+
+	// DisableBlockCache allows disabling the use of cache.Cache functionality
+	// on 'sorted table' blocks.
+	//
+	// The default value is false.
+	DisableBlockCache bool
+
+	// DisableCompactionBackoff allows disabling compaction retry backoff.
+	//
+	// The default value is false.
+	DisableCompactionBackoff bool
+
+	// ErrorIfExist defines whether an error should be returned if the DB
+	// already exists.
+	//
+	// The default value is false.
+	ErrorIfExist bool
+
+	// ErrorIfMissing defines whether an error should be returned if the DB is
+	// missing. If false then the database will be created if missing, otherwise
+	// an error will be returned.
+	//
+	// The default value is false.
+	ErrorIfMissing bool
+
+	// Filter defines an 'effective filter' to use. An 'effective filter',
+	// if defined, will be used to generate per-table filter blocks.
+	// The filter name will be stored on disk.
+	// During reads LevelDB will try to find a matching filter from the
+	// 'effective filter' and 'alternative filters'.
+	//
+	// Filter can be changed after a DB has been created. It is recommended
+	// to put the old filter into the 'alternative filters' to mitigate the
+	// lack of a filter during the transition period.
+	//
+	// A filter is used to reduce disk reads when looking for a specific key.
+	//
+	// The default value is nil.
+	Filter filter.Filter
+
+	// MaxMemCompationLevel defines the maximum level that a newly compacted
+	// 'memdb' will be pushed into if it doesn't create overlap. This should
+	// be less than NumLevel. Use -1 for level-0.
+	//
+	// The default is 2.
+	MaxMemCompationLevel int
+
+	// NumLevel defines the number of database levels. The number of levels
+	// shouldn't change between opens, or the database will panic.
+	//
+	// The default is 7.
+	NumLevel int
+
+	// OpenFilesCacher provides the cache algorithm for open files caching.
+	// Specify NoCacher to disable the caching algorithm.
+	//
+	// The default value is LRUCacher.
+	OpenFilesCacher Cacher
+
+	// OpenFilesCacheCapacity defines the capacity of the open files caching.
+	// Use -1 for zero; this has the same effect as specifying NoCacher to
+	// OpenFilesCacher.
+	//
+	// The default value is 500.
+	OpenFilesCacheCapacity int
+
+	// Strict defines the DB strict level.
+	Strict Strict
+
+	// WriteBuffer defines the maximum size of a 'memdb' before it is flushed
+	// to a 'sorted table'. 'memdb' is an in-memory DB backed by an on-disk
+	// unsorted journal.
+	//
+	// LevelDB may hold up to two 'memdb' at the same time.
+	//
+	// The default value is 4MiB.
+	WriteBuffer int
+
+	// WriteL0PauseTrigger defines the number of 'sorted table' at level-0
+	// that will pause write.
+	//
+	// The default value is 12.
+	WriteL0PauseTrigger int
+
+	// WriteL0SlowdownTrigger defines the number of 'sorted table' at level-0
+	// that will trigger write slowdown.
+	//
+	// The default value is 8.
+ WriteL0SlowdownTrigger int +} + +func (o *Options) GetAltFilters() []filter.Filter { + if o == nil { + return nil + } + return o.AltFilters +} + +func (o *Options) GetBlockCacher() Cacher { + if o == nil || o.BlockCacher == nil { + return DefaultBlockCacher + } else if o.BlockCacher == NoCacher { + return nil + } + return o.BlockCacher +} + +func (o *Options) GetBlockCacheCapacity() int { + if o == nil || o.BlockCacheCapacity <= 0 { + return DefaultBlockCacheCapacity + } else if o.BlockCacheCapacity == -1 { + return 0 + } + return o.BlockCacheCapacity +} + +func (o *Options) GetBlockRestartInterval() int { + if o == nil || o.BlockRestartInterval <= 0 { + return DefaultBlockRestartInterval + } + return o.BlockRestartInterval +} + +func (o *Options) GetBlockSize() int { + if o == nil || o.BlockSize <= 0 { + return DefaultBlockSize + } + return o.BlockSize +} + +func (o *Options) GetCompactionExpandLimit(level int) int { + factor := DefaultCompactionExpandLimitFactor + if o != nil && o.CompactionExpandLimitFactor > 0 { + factor = o.CompactionExpandLimitFactor + } + return o.GetCompactionTableSize(level+1) * factor +} + +func (o *Options) GetCompactionGPOverlaps(level int) int { + factor := DefaultCompactionGPOverlapsFactor + if o != nil && o.CompactionGPOverlapsFactor > 0 { + factor = o.CompactionGPOverlapsFactor + } + return o.GetCompactionTableSize(level+2) * factor +} + +func (o *Options) GetCompactionL0Trigger() int { + if o == nil || o.CompactionL0Trigger == 0 { + return DefaultCompactionL0Trigger + } + return o.CompactionL0Trigger +} + +func (o *Options) GetCompactionSourceLimit(level int) int { + factor := DefaultCompactionSourceLimitFactor + if o != nil && o.CompactionSourceLimitFactor > 0 { + factor = o.CompactionSourceLimitFactor + } + return o.GetCompactionTableSize(level+1) * factor +} + +func (o *Options) GetCompactionTableSize(level int) int { + var ( + base = DefaultCompactionTableSize + mult float64 + ) + if o != nil { + if o.CompactionTableSize > 0 { + base = o.CompactionTableSize + } + if len(o.CompactionTableSizeMultiplierPerLevel) > level && o.CompactionTableSizeMultiplierPerLevel[level] > 0 { + mult = o.CompactionTableSizeMultiplierPerLevel[level] + } else if o.CompactionTableSizeMultiplier > 0 { + mult = math.Pow(o.CompactionTableSizeMultiplier, float64(level)) + } + } + if mult == 0 { + mult = math.Pow(DefaultCompactionTableSizeMultiplier, float64(level)) + } + return int(float64(base) * mult) +} + +func (o *Options) GetCompactionTotalSize(level int) int64 { + var ( + base = DefaultCompactionTotalSize + mult float64 + ) + if o != nil { + if o.CompactionTotalSize > 0 { + base = o.CompactionTotalSize + } + if len(o.CompactionTotalSizeMultiplierPerLevel) > level && o.CompactionTotalSizeMultiplierPerLevel[level] > 0 { + mult = o.CompactionTotalSizeMultiplierPerLevel[level] + } else if o.CompactionTotalSizeMultiplier > 0 { + mult = math.Pow(o.CompactionTotalSizeMultiplier, float64(level)) + } + } + if mult == 0 { + mult = math.Pow(DefaultCompactionTotalSizeMultiplier, float64(level)) + } + return int64(float64(base) * mult) +} + +func (o *Options) GetComparer() comparer.Comparer { + if o == nil || o.Comparer == nil { + return comparer.DefaultComparer + } + return o.Comparer +} + +func (o *Options) GetCompression() Compression { + if o == nil || o.Compression <= DefaultCompression || o.Compression >= nCompression { + return DefaultCompressionType + } + return o.Compression +} + +func (o *Options) GetDisableCompactionBackoff() bool { + if o == nil { + return false + } + 
return o.DisableCompactionBackoff +} + +func (o *Options) GetErrorIfExist() bool { + if o == nil { + return false + } + return o.ErrorIfExist +} + +func (o *Options) GetErrorIfMissing() bool { + if o == nil { + return false + } + return o.ErrorIfMissing +} + +func (o *Options) GetFilter() filter.Filter { + if o == nil { + return nil + } + return o.Filter +} + +func (o *Options) GetMaxMemCompationLevel() int { + level := DefaultMaxMemCompationLevel + if o != nil { + if o.MaxMemCompationLevel > 0 { + level = o.MaxMemCompationLevel + } else if o.MaxMemCompationLevel == -1 { + level = 0 + } + } + if level >= o.GetNumLevel() { + return o.GetNumLevel() - 1 + } + return level +} + +func (o *Options) GetNumLevel() int { + if o == nil || o.NumLevel <= 0 { + return DefaultNumLevel + } + return o.NumLevel +} + +func (o *Options) GetOpenFilesCacher() Cacher { + if o == nil || o.OpenFilesCacher == nil { + return DefaultOpenFilesCacher + } + if o.OpenFilesCacher == NoCacher { + return nil + } + return o.OpenFilesCacher +} + +func (o *Options) GetOpenFilesCacheCapacity() int { + if o == nil || o.OpenFilesCacheCapacity <= 0 { + return DefaultOpenFilesCacheCapacity + } else if o.OpenFilesCacheCapacity == -1 { + return 0 + } + return o.OpenFilesCacheCapacity +} + +func (o *Options) GetStrict(strict Strict) bool { + if o == nil || o.Strict == 0 { + return DefaultStrict&strict != 0 + } + return o.Strict&strict != 0 +} + +func (o *Options) GetWriteBuffer() int { + if o == nil || o.WriteBuffer <= 0 { + return DefaultWriteBuffer + } + return o.WriteBuffer +} + +func (o *Options) GetWriteL0PauseTrigger() int { + if o == nil || o.WriteL0PauseTrigger == 0 { + return DefaultWriteL0PauseTrigger + } + return o.WriteL0PauseTrigger +} + +func (o *Options) GetWriteL0SlowdownTrigger() int { + if o == nil || o.WriteL0SlowdownTrigger == 0 { + return DefaultWriteL0SlowdownTrigger + } + return o.WriteL0SlowdownTrigger +} + +// ReadOptions holds the optional parameters for 'read operation'. The +// 'read operation' includes Get, Find and NewIterator. +type ReadOptions struct { + // DontFillCache defines whether block reads for this 'read operation' + // should be cached. If false then the block will be cached. This does + // not affects already cached block. + // + // The default value is false. + DontFillCache bool + + // Strict will be OR'ed with global DB 'strict level' unless StrictOverride + // is present. Currently only StrictReader that has effect here. + Strict Strict +} + +func (ro *ReadOptions) GetDontFillCache() bool { + if ro == nil { + return false + } + return ro.DontFillCache +} + +func (ro *ReadOptions) GetStrict(strict Strict) bool { + if ro == nil { + return false + } + return ro.Strict&strict != 0 +} + +// WriteOptions holds the optional parameters for 'write operation'. The +// 'write operation' includes Write, Put and Delete. +type WriteOptions struct { + // Sync is whether to sync underlying writes from the OS buffer cache + // through to actual disk, if applicable. Setting Sync can result in + // slower writes. + // + // If false, and the machine crashes, then some recent writes may be lost. + // Note that if it is just the process that crashes (and the machine does + // not) then no writes will be lost. + // + // In other words, Sync being false has the same semantics as a write + // system call. Sync being true means write followed by fsync. + // + // The default value is false. 
+ Sync bool +} + +func (wo *WriteOptions) GetSync() bool { + if wo == nil { + return false + } + return wo.Sync +} + +func GetStrict(o *Options, ro *ReadOptions, strict Strict) bool { + if ro.GetStrict(StrictOverride) { + return ro.GetStrict(strict) + } else { + return o.GetStrict(strict) || ro.GetStrict(strict) + } +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/options.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/options.go new file mode 100644 index 0000000000..a3d84ef60d --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/options.go @@ -0,0 +1,92 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package leveldb + +import ( + "github.com/syndtr/goleveldb/leveldb/filter" + "github.com/syndtr/goleveldb/leveldb/opt" +) + +func dupOptions(o *opt.Options) *opt.Options { + newo := &opt.Options{} + if o != nil { + *newo = *o + } + if newo.Strict == 0 { + newo.Strict = opt.DefaultStrict + } + return newo +} + +func (s *session) setOptions(o *opt.Options) { + no := dupOptions(o) + // Alternative filters. + if filters := o.GetAltFilters(); len(filters) > 0 { + no.AltFilters = make([]filter.Filter, len(filters)) + for i, filter := range filters { + no.AltFilters[i] = &iFilter{filter} + } + } + // Comparer. + s.icmp = &iComparer{o.GetComparer()} + no.Comparer = s.icmp + // Filter. + if filter := o.GetFilter(); filter != nil { + no.Filter = &iFilter{filter} + } + + s.o = &cachedOptions{Options: no} + s.o.cache() +} + +type cachedOptions struct { + *opt.Options + + compactionExpandLimit []int + compactionGPOverlaps []int + compactionSourceLimit []int + compactionTableSize []int + compactionTotalSize []int64 +} + +func (co *cachedOptions) cache() { + numLevel := co.Options.GetNumLevel() + + co.compactionExpandLimit = make([]int, numLevel) + co.compactionGPOverlaps = make([]int, numLevel) + co.compactionSourceLimit = make([]int, numLevel) + co.compactionTableSize = make([]int, numLevel) + co.compactionTotalSize = make([]int64, numLevel) + + for level := 0; level < numLevel; level++ { + co.compactionExpandLimit[level] = co.Options.GetCompactionExpandLimit(level) + co.compactionGPOverlaps[level] = co.Options.GetCompactionGPOverlaps(level) + co.compactionSourceLimit[level] = co.Options.GetCompactionSourceLimit(level) + co.compactionTableSize[level] = co.Options.GetCompactionTableSize(level) + co.compactionTotalSize[level] = co.Options.GetCompactionTotalSize(level) + } +} + +func (co *cachedOptions) GetCompactionExpandLimit(level int) int { + return co.compactionExpandLimit[level] +} + +func (co *cachedOptions) GetCompactionGPOverlaps(level int) int { + return co.compactionGPOverlaps[level] +} + +func (co *cachedOptions) GetCompactionSourceLimit(level int) int { + return co.compactionSourceLimit[level] +} + +func (co *cachedOptions) GetCompactionTableSize(level int) int { + return co.compactionTableSize[level] +} + +func (co *cachedOptions) GetCompactionTotalSize(level int) int64 { + return co.compactionTotalSize[level] +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/session.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/session.go new file mode 100644 index 0000000000..b3906f7fc8 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/session.go @@ -0,0 +1,455 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. 
+// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package leveldb + +import ( + "fmt" + "io" + "os" + "sync" + "sync/atomic" + + "github.com/syndtr/goleveldb/leveldb/errors" + "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/journal" + "github.com/syndtr/goleveldb/leveldb/opt" + "github.com/syndtr/goleveldb/leveldb/storage" + "github.com/syndtr/goleveldb/leveldb/util" +) + +type ErrManifestCorrupted struct { + Field string + Reason string +} + +func (e *ErrManifestCorrupted) Error() string { + return fmt.Sprintf("leveldb: manifest corrupted (field '%s'): %s", e.Field, e.Reason) +} + +func newErrManifestCorrupted(f storage.File, field, reason string) error { + return errors.NewErrCorrupted(f, &ErrManifestCorrupted{field, reason}) +} + +// session represent a persistent database session. +type session struct { + // Need 64-bit alignment. + stNextFileNum uint64 // current unused file number + stJournalNum uint64 // current journal file number; need external synchronization + stPrevJournalNum uint64 // prev journal file number; no longer used; for compatibility with older version of leveldb + stSeqNum uint64 // last mem compacted seq; need external synchronization + stTempFileNum uint64 + + stor storage.Storage + storLock util.Releaser + o *cachedOptions + icmp *iComparer + tops *tOps + + manifest *journal.Writer + manifestWriter storage.Writer + manifestFile storage.File + + stCompPtrs []iKey // compaction pointers; need external synchronization + stVersion *version // current version + vmu sync.Mutex +} + +// Creates new initialized session instance. +func newSession(stor storage.Storage, o *opt.Options) (s *session, err error) { + if stor == nil { + return nil, os.ErrInvalid + } + storLock, err := stor.Lock() + if err != nil { + return + } + s = &session{ + stor: stor, + storLock: storLock, + stCompPtrs: make([]iKey, o.GetNumLevel()), + } + s.setOptions(o) + s.tops = newTableOps(s) + s.setVersion(newVersion(s)) + s.log("log@legend F·NumFile S·FileSize N·Entry C·BadEntry B·BadBlock Ke·KeyError D·DroppedEntry L·Level Q·SeqNum T·TimeElapsed") + return +} + +// Close session. +func (s *session) close() { + s.tops.close() + if s.manifest != nil { + s.manifest.Close() + } + if s.manifestWriter != nil { + s.manifestWriter.Close() + } + s.manifest = nil + s.manifestWriter = nil + s.manifestFile = nil + s.stVersion = nil +} + +// Release session lock. +func (s *session) release() { + s.storLock.Release() +} + +// Create a new database session; need external synchronization. +func (s *session) create() error { + // create manifest + return s.newManifest(nil, nil) +} + +// Recover a database session; need external synchronization. +func (s *session) recover() (err error) { + defer func() { + if os.IsNotExist(err) { + // Don't return os.ErrNotExist if the underlying storage contains + // other files that belong to LevelDB. So the DB won't get trashed. 
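+			// Reporting corruption instead of os.ErrNotExist lets the caller
+			// distinguish a damaged DB (table or journal files without a
+			// manifest) from a directory that simply contains no DB.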
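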
+ if files, _ := s.stor.GetFiles(storage.TypeAll); len(files) > 0 { + err = &errors.ErrCorrupted{File: &storage.FileInfo{Type: storage.TypeManifest}, Err: &errors.ErrMissingFiles{}} + } + } + }() + + m, err := s.stor.GetManifest() + if err != nil { + return + } + + reader, err := m.Open() + if err != nil { + return + } + defer reader.Close() + strict := s.o.GetStrict(opt.StrictManifest) + jr := journal.NewReader(reader, dropper{s, m}, strict, true) + + staging := s.stVersion.newStaging() + rec := &sessionRecord{numLevel: s.o.GetNumLevel()} + for { + var r io.Reader + r, err = jr.Next() + if err != nil { + if err == io.EOF { + err = nil + break + } + return errors.SetFile(err, m) + } + + err = rec.decode(r) + if err == nil { + // save compact pointers + for _, r := range rec.compPtrs { + s.stCompPtrs[r.level] = iKey(r.ikey) + } + // commit record to version staging + staging.commit(rec) + } else { + err = errors.SetFile(err, m) + if strict || !errors.IsCorrupted(err) { + return + } else { + s.logf("manifest error: %v (skipped)", errors.SetFile(err, m)) + } + } + rec.resetCompPtrs() + rec.resetAddedTables() + rec.resetDeletedTables() + } + + switch { + case !rec.has(recComparer): + return newErrManifestCorrupted(m, "comparer", "missing") + case rec.comparer != s.icmp.uName(): + return newErrManifestCorrupted(m, "comparer", fmt.Sprintf("mismatch: want '%s', got '%s'", s.icmp.uName(), rec.comparer)) + case !rec.has(recNextFileNum): + return newErrManifestCorrupted(m, "next-file-num", "missing") + case !rec.has(recJournalNum): + return newErrManifestCorrupted(m, "journal-file-num", "missing") + case !rec.has(recSeqNum): + return newErrManifestCorrupted(m, "seq-num", "missing") + } + + s.manifestFile = m + s.setVersion(staging.finish()) + s.setNextFileNum(rec.nextFileNum) + s.recordCommited(rec) + return nil +} + +// Commit session; need external synchronization. +func (s *session) commit(r *sessionRecord) (err error) { + v := s.version() + defer v.release() + + // spawn new version based on current version + nv := v.spawn(r) + + if s.manifest == nil { + // manifest journal writer not yet created, create one + err = s.newManifest(r, nv) + } else { + err = s.flushManifest(r) + } + + // finally, apply new version if no error rise + if err == nil { + s.setVersion(nv) + } + + return +} + +// Pick a compaction based on current state; need external synchronization. +func (s *session) pickCompaction() *compaction { + v := s.version() + + var level int + var t0 tFiles + if v.cScore >= 1 { + level = v.cLevel + cptr := s.stCompPtrs[level] + tables := v.tables[level] + for _, t := range tables { + if cptr == nil || s.icmp.Compare(t.imax, cptr) > 0 { + t0 = append(t0, t) + break + } + } + if len(t0) == 0 { + t0 = append(t0, tables[0]) + } + } else { + if p := atomic.LoadPointer(&v.cSeek); p != nil { + ts := (*tSet)(p) + level = ts.level + t0 = append(t0, ts.table) + } else { + v.release() + return nil + } + } + + return newCompaction(s, v, level, t0) +} + +// Create compaction from given level and range; need external synchronization. +func (s *session) getCompactionRange(level int, umin, umax []byte) *compaction { + v := s.version() + + t0 := v.tables[level].getOverlaps(nil, s.icmp, umin, umax, level == 0) + if len(t0) == 0 { + v.release() + return nil + } + + // Avoid compacting too much in one shot in case the range is large. + // But we cannot do this for level-0 since level-0 files can overlap + // and we must not pick one file and drop another older file if the + // two files overlap. 
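+	// (Compacting only a subset of overlapping level-0 tables could move a
+	// newer update to a deeper level than an older one left behind in
+	// level-0, resurrecting a stale value.)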
+ if level > 0 { + limit := uint64(v.s.o.GetCompactionSourceLimit(level)) + total := uint64(0) + for i, t := range t0 { + total += t.size + if total >= limit { + s.logf("table@compaction limiting F·%d -> F·%d", len(t0), i+1) + t0 = t0[:i+1] + break + } + } + } + + return newCompaction(s, v, level, t0) +} + +func newCompaction(s *session, v *version, level int, t0 tFiles) *compaction { + c := &compaction{ + s: s, + v: v, + level: level, + tables: [2]tFiles{t0, nil}, + maxGPOverlaps: uint64(s.o.GetCompactionGPOverlaps(level)), + tPtrs: make([]int, s.o.GetNumLevel()), + } + c.expand() + c.save() + return c +} + +// compaction represent a compaction state. +type compaction struct { + s *session + v *version + + level int + tables [2]tFiles + maxGPOverlaps uint64 + + gp tFiles + gpi int + seenKey bool + gpOverlappedBytes uint64 + imin, imax iKey + tPtrs []int + released bool + + snapGPI int + snapSeenKey bool + snapGPOverlappedBytes uint64 + snapTPtrs []int +} + +func (c *compaction) save() { + c.snapGPI = c.gpi + c.snapSeenKey = c.seenKey + c.snapGPOverlappedBytes = c.gpOverlappedBytes + c.snapTPtrs = append(c.snapTPtrs[:0], c.tPtrs...) +} + +func (c *compaction) restore() { + c.gpi = c.snapGPI + c.seenKey = c.snapSeenKey + c.gpOverlappedBytes = c.snapGPOverlappedBytes + c.tPtrs = append(c.tPtrs[:0], c.snapTPtrs...) +} + +func (c *compaction) release() { + if !c.released { + c.released = true + c.v.release() + } +} + +// Expand compacted tables; need external synchronization. +func (c *compaction) expand() { + limit := uint64(c.s.o.GetCompactionExpandLimit(c.level)) + vt0, vt1 := c.v.tables[c.level], c.v.tables[c.level+1] + + t0, t1 := c.tables[0], c.tables[1] + imin, imax := t0.getRange(c.s.icmp) + // We expand t0 here just incase ukey hop across tables. + t0 = vt0.getOverlaps(t0, c.s.icmp, imin.ukey(), imax.ukey(), c.level == 0) + if len(t0) != len(c.tables[0]) { + imin, imax = t0.getRange(c.s.icmp) + } + t1 = vt1.getOverlaps(t1, c.s.icmp, imin.ukey(), imax.ukey(), false) + // Get entire range covered by compaction. + amin, amax := append(t0, t1...).getRange(c.s.icmp) + + // See if we can grow the number of inputs in "level" without + // changing the number of "level+1" files we pick up. + if len(t1) > 0 { + exp0 := vt0.getOverlaps(nil, c.s.icmp, amin.ukey(), amax.ukey(), c.level == 0) + if len(exp0) > len(t0) && t1.size()+exp0.size() < limit { + xmin, xmax := exp0.getRange(c.s.icmp) + exp1 := vt1.getOverlaps(nil, c.s.icmp, xmin.ukey(), xmax.ukey(), false) + if len(exp1) == len(t1) { + c.s.logf("table@compaction expanding L%d+L%d (F·%d S·%s)+(F·%d S·%s) -> (F·%d S·%s)+(F·%d S·%s)", + c.level, c.level+1, len(t0), shortenb(int(t0.size())), len(t1), shortenb(int(t1.size())), + len(exp0), shortenb(int(exp0.size())), len(exp1), shortenb(int(exp1.size()))) + imin, imax = xmin, xmax + t0, t1 = exp0, exp1 + amin, amax = append(t0, t1...).getRange(c.s.icmp) + } + } + } + + // Compute the set of grandparent files that overlap this compaction + // (parent == level+1; grandparent == level+2) + if c.level+2 < c.s.o.GetNumLevel() { + c.gp = c.v.tables[c.level+2].getOverlaps(c.gp, c.s.icmp, amin.ukey(), amax.ukey(), false) + } + + c.tables[0], c.tables[1] = t0, t1 + c.imin, c.imax = imin, imax +} + +// Check whether compaction is trivial. 
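+// A trivial compaction simply moves the single input table to the next
+// level without merging, provided it does not overlap too large a portion
+// of the grandparent level (see maxGPOverlaps).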
+func (c *compaction) trivial() bool { + return len(c.tables[0]) == 1 && len(c.tables[1]) == 0 && c.gp.size() <= c.maxGPOverlaps +} + +func (c *compaction) baseLevelForKey(ukey []byte) bool { + for level, tables := range c.v.tables[c.level+2:] { + for c.tPtrs[level] < len(tables) { + t := tables[c.tPtrs[level]] + if c.s.icmp.uCompare(ukey, t.imax.ukey()) <= 0 { + // We've advanced far enough. + if c.s.icmp.uCompare(ukey, t.imin.ukey()) >= 0 { + // Key falls in this file's range, so definitely not base level. + return false + } + break + } + c.tPtrs[level]++ + } + } + return true +} + +func (c *compaction) shouldStopBefore(ikey iKey) bool { + for ; c.gpi < len(c.gp); c.gpi++ { + gp := c.gp[c.gpi] + if c.s.icmp.Compare(ikey, gp.imax) <= 0 { + break + } + if c.seenKey { + c.gpOverlappedBytes += gp.size + } + } + c.seenKey = true + + if c.gpOverlappedBytes > c.maxGPOverlaps { + // Too much overlap for current output; start new output. + c.gpOverlappedBytes = 0 + return true + } + return false +} + +// Creates an iterator. +func (c *compaction) newIterator() iterator.Iterator { + // Creates iterator slice. + icap := len(c.tables) + if c.level == 0 { + // Special case for level-0 + icap = len(c.tables[0]) + 1 + } + its := make([]iterator.Iterator, 0, icap) + + // Options. + ro := &opt.ReadOptions{ + DontFillCache: true, + Strict: opt.StrictOverride, + } + strict := c.s.o.GetStrict(opt.StrictCompaction) + if strict { + ro.Strict |= opt.StrictReader + } + + for i, tables := range c.tables { + if len(tables) == 0 { + continue + } + + // Level-0 is not sorted and may overlaps each other. + if c.level+i == 0 { + for _, t := range tables { + its = append(its, c.s.tops.newIterator(t, nil, ro)) + } + } else { + it := iterator.NewIndexedIterator(tables.newIndexIterator(c.s.tops, c.s.icmp, nil, ro), strict) + its = append(its, it) + } + } + + return iterator.NewMergedIterator(its, c.s.icmp, strict) +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/session_record.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/session_record.go new file mode 100644 index 0000000000..1bdcc68f5f --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/session_record.go @@ -0,0 +1,313 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package leveldb + +import ( + "bufio" + "encoding/binary" + "io" + "strings" + + "github.com/syndtr/goleveldb/leveldb/errors" +) + +type byteReader interface { + io.Reader + io.ByteReader +} + +// These numbers are written to disk and should not be changed. 
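+// Each manifest record field is encoded as a uvarint tag from this list
+// followed by its payload; see decode below.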
+const (
+	recComparer    = 1
+	recJournalNum  = 2
+	recNextFileNum = 3
+	recSeqNum      = 4
+	recCompPtr     = 5
+	recDelTable    = 6
+	recAddTable    = 7
+	// 8 was used for large value refs
+	recPrevJournalNum = 9
+)
+
+type cpRecord struct {
+	level int
+	ikey  iKey
+}
+
+type atRecord struct {
+	level int
+	num   uint64
+	size  uint64
+	imin  iKey
+	imax  iKey
+}
+
+type dtRecord struct {
+	level int
+	num   uint64
+}
+
+type sessionRecord struct {
+	numLevel int
+
+	hasRec         int
+	comparer       string
+	journalNum     uint64
+	prevJournalNum uint64
+	nextFileNum    uint64
+	seqNum         uint64
+	compPtrs       []cpRecord
+	addedTables    []atRecord
+	deletedTables  []dtRecord
+
+	scratch [binary.MaxVarintLen64]byte
+	err     error
+}
+
+func (p *sessionRecord) has(rec int) bool {
+	return p.hasRec&(1<<uint(rec)) != 0
+}
+
+func (p *sessionRecord) setComparer(name string) {
+	p.hasRec |= 1 << recComparer
+	p.comparer = name
+}
+
+func (p *sessionRecord) setJournalNum(num uint64) {
+	p.hasRec |= 1 << recJournalNum
+	p.journalNum = num
+}
+
+func (p *sessionRecord) setPrevJournalNum(num uint64) {
+	p.hasRec |= 1 << recPrevJournalNum
+	p.prevJournalNum = num
+}
+
+func (p *sessionRecord) setNextFileNum(num uint64) {
+	p.hasRec |= 1 << recNextFileNum
+	p.nextFileNum = num
+}
+
+func (p *sessionRecord) setSeqNum(num uint64) {
+	p.hasRec |= 1 << recSeqNum
+	p.seqNum = num
+}
+
+func (p *sessionRecord) addCompPtr(level int, ikey iKey) {
+	p.hasRec |= 1 << recCompPtr
+	p.compPtrs = append(p.compPtrs, cpRecord{level, ikey})
+}
+
+func (p *sessionRecord) resetCompPtrs() {
+	p.hasRec &= ^(1 << recCompPtr)
+	p.compPtrs = p.compPtrs[:0]
+}
+
+func (p *sessionRecord) addTable(level int, num, size uint64, imin, imax iKey) {
+	p.hasRec |= 1 << recAddTable
+	p.addedTables = append(p.addedTables, atRecord{level, num, size, imin, imax})
+}
+
+func (p *sessionRecord) addTableFile(level int, t *tFile) {
+	p.addTable(level, t.file.Num(), t.size, t.imin, t.imax)
+}
+
+func (p *sessionRecord) resetAddedTables() {
+	p.hasRec &= ^(1 << recAddTable)
+	p.addedTables = p.addedTables[:0]
+}
+
+func (p *sessionRecord) delTable(level int, num uint64) {
+	p.hasRec |= 1 << recDelTable
+	p.deletedTables = append(p.deletedTables, dtRecord{level, num})
+}
+
+func (p *sessionRecord) resetDeletedTables() {
+	p.hasRec &= ^(1 << recDelTable)
+	p.deletedTables = p.deletedTables[:0]
+}
+
+func (p *sessionRecord) putUvarint(w io.Writer, x uint64) {
+	if p.err != nil {
+		return
+	}
+	n := binary.PutUvarint(p.scratch[:], x)
+	_, p.err = w.Write(p.scratch[:n])
+}
+
+func (p *sessionRecord) putVarint(w io.Writer, x int) {
+	if x < 0 {
+		panic("invalid negative value")
+	}
+	p.putUvarint(w, uint64(x))
+}
+
+func (p *sessionRecord) putBytes(w io.Writer, x []byte) {
+	if p.err != nil {
+		return
+	}
+	p.putUvarint(w, uint64(len(x)))
+	if p.err != nil {
+		return
+	}
+	_, p.err = w.Write(x)
+}
+
+func (p *sessionRecord) encode(w io.Writer) error {
+	p.err = nil
+	if p.has(recComparer) {
+		p.putUvarint(w, recComparer)
+		p.putBytes(w, []byte(p.comparer))
+	}
+	if p.has(recJournalNum) {
+		p.putUvarint(w, recJournalNum)
+		p.putUvarint(w, p.journalNum)
+	}
+	if p.has(recNextFileNum) {
+		p.putUvarint(w, recNextFileNum)
+		p.putUvarint(w, p.nextFileNum)
+	}
+	if p.has(recSeqNum) {
+		p.putUvarint(w, recSeqNum)
+		p.putUvarint(w, p.seqNum)
+	}
+	for _, r := range p.compPtrs {
+		p.putUvarint(w, recCompPtr)
+		p.putVarint(w, r.level)
+		p.putBytes(w, r.ikey)
+	}
+	for _, r := range p.deletedTables {
+		p.putUvarint(w, recDelTable)
+		p.putVarint(w, r.level)
+		p.putUvarint(w, r.num)
+	}
+	for _, r := range p.addedTables {
+		p.putUvarint(w, recAddTable)
+		p.putVarint(w, r.level)
+		p.putUvarint(w, r.num)
+		p.putUvarint(w, r.size)
+		p.putBytes(w, r.imin)
+		p.putBytes(w, r.imax)
+	}
+	return p.err
+}
+
+func (p *sessionRecord) readUvarintMayEOF(field string, r io.ByteReader, mayEOF bool) uint64 {
+	if p.err != nil {
+		return 0
+	}
+	x, err := binary.ReadUvarint(r)
+	if err != nil {
+		if err == io.ErrUnexpectedEOF || (mayEOF == false && err == io.EOF) {
+			p.err = errors.NewErrCorrupted(nil, &ErrManifestCorrupted{field, "short read"})
+		} else if strings.HasPrefix(err.Error(), "leveldb/journal:") {
+			p.err = errors.NewErrCorrupted(nil, &ErrManifestCorrupted{field, err.Error()})
+		} else {
+			p.err = err
+		}
+		return 0
+	}
+	return x
+}
+
+func (p *sessionRecord) readUvarint(field string, r io.ByteReader) uint64 {
+	return p.readUvarintMayEOF(field, r, false)
+}
+
+func (p *sessionRecord) readBytes(field string, r byteReader) []byte {
+	if p.err != nil {
+		return nil
+	}
+	n := p.readUvarint(field, r)
+	if p.err != nil {
+		return nil
+	}
+	x := make([]byte, n)
+	_, p.err = io.ReadFull(r, x)
+	if p.err != nil {
+		if p.err == io.ErrUnexpectedEOF {
+			p.err = errors.NewErrCorrupted(nil, &ErrManifestCorrupted{field, "short read"})
+		}
+		return nil
+	}
+	return x
+}
+
+func (p *sessionRecord) readLevel(field string, r io.ByteReader) int {
+	if p.err != nil {
+		return 0
+	}
+	x := p.readUvarint(field, r)
+	if p.err != nil {
+		return 0
+	}
+	if x >= uint64(p.numLevel) {
+		p.err = errors.NewErrCorrupted(nil, &ErrManifestCorrupted{field, "invalid level number"})
+		return 0
+	}
+	return int(x)
+}
+
+func (p *sessionRecord) decode(r io.Reader) error {
+	br, ok := r.(byteReader)
+	if !ok {
+		br = bufio.NewReader(r)
+	}
+	p.err = nil
+	for p.err == nil {
+		rec := p.readUvarintMayEOF("field-header", br, true)
+		if p.err != nil {
+			if p.err == io.EOF {
+				return nil
+			}
+			return p.err
+		}
+		switch rec {
+		case recComparer:
+			x := p.readBytes("comparer", br)
+			if p.err == nil {
+				p.setComparer(string(x))
+			}
+		case recJournalNum:
+			x := p.readUvarint("journal-num", br)
+			if p.err == nil {
+				p.setJournalNum(x)
+			}
+		case recPrevJournalNum:
+			x := p.readUvarint("prev-journal-num", br)
+			if p.err == nil {
+				p.setPrevJournalNum(x)
+			}
+		case recNextFileNum:
+			x := p.readUvarint("next-file-num", br)
+			if p.err == nil {
+				p.setNextFileNum(x)
+			}
+		case recSeqNum:
+			x := p.readUvarint("seq-num", br)
+			if p.err == nil {
+				p.setSeqNum(x)
+			}
+		case recCompPtr:
+			level := p.readLevel("comp-ptr.level", br)
+			ikey := p.readBytes("comp-ptr.ikey", br)
+			if p.err == nil {
+				p.addCompPtr(level, iKey(ikey))
+			}
+		case recAddTable:
+			level := p.readLevel("add-table.level", br)
+			num := p.readUvarint("add-table.num", br)
+			size := p.readUvarint("add-table.size", br)
+			imin := p.readBytes("add-table.imin", br)
+			imax := p.readBytes("add-table.imax", br)
+			if p.err == nil {
+				p.addTable(level, num, size, imin, imax)
+			}
+		case recDelTable:
+			level := p.readLevel("del-table.level", br)
+			num := p.readUvarint("del-table.num", br)
+			if p.err == nil {
+				p.delTable(level, num)
+			}
+		}
+	}
+
+	return p.err
+}
diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/session_record_test.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/session_record_test.go
new file mode 100644
index 0000000000..c0c035ae39
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/session_record_test.go
@@ -0,0 +1,64 @@
+// Copyright (c) 2012, Suryandaru Triandana <syndtr@gmail.com>
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+ +package leveldb + +import ( + "bytes" + "testing" + + "github.com/syndtr/goleveldb/leveldb/opt" +) + +func decodeEncode(v *sessionRecord) (res bool, err error) { + b := new(bytes.Buffer) + err = v.encode(b) + if err != nil { + return + } + v2 := &sessionRecord{numLevel: opt.DefaultNumLevel} + err = v.decode(b) + if err != nil { + return + } + b2 := new(bytes.Buffer) + err = v2.encode(b2) + if err != nil { + return + } + return bytes.Equal(b.Bytes(), b2.Bytes()), nil +} + +func TestSessionRecord_EncodeDecode(t *testing.T) { + big := uint64(1) << 50 + v := &sessionRecord{numLevel: opt.DefaultNumLevel} + i := uint64(0) + test := func() { + res, err := decodeEncode(v) + if err != nil { + t.Fatalf("error when testing encode/decode sessionRecord: %v", err) + } + if !res { + t.Error("encode/decode test failed at iteration:", i) + } + } + + for ; i < 4; i++ { + test() + v.addTable(3, big+300+i, big+400+i, + newIkey([]byte("foo"), big+500+1, ktVal), + newIkey([]byte("zoo"), big+600+1, ktDel)) + v.delTable(4, big+700+i) + v.addCompPtr(int(i), newIkey([]byte("x"), big+900+1, ktVal)) + } + + v.setComparer("foo") + v.setJournalNum(big + 100) + v.setPrevJournalNum(big + 99) + v.setNextFileNum(big + 200) + v.setSeqNum(big + 1000) + test() +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/session_util.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/session_util.go new file mode 100644 index 0000000000..007c02cde9 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/session_util.go @@ -0,0 +1,249 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package leveldb + +import ( + "fmt" + "sync/atomic" + + "github.com/syndtr/goleveldb/leveldb/journal" + "github.com/syndtr/goleveldb/leveldb/storage" +) + +// Logging. + +type dropper struct { + s *session + file storage.File +} + +func (d dropper) Drop(err error) { + if e, ok := err.(*journal.ErrCorrupted); ok { + d.s.logf("journal@drop %s-%d S·%s %q", d.file.Type(), d.file.Num(), shortenb(e.Size), e.Reason) + } else { + d.s.logf("journal@drop %s-%d %q", d.file.Type(), d.file.Num(), err) + } +} + +func (s *session) log(v ...interface{}) { s.stor.Log(fmt.Sprint(v...)) } +func (s *session) logf(format string, v ...interface{}) { s.stor.Log(fmt.Sprintf(format, v...)) } + +// File utils. + +func (s *session) getJournalFile(num uint64) storage.File { + return s.stor.GetFile(num, storage.TypeJournal) +} + +func (s *session) getTableFile(num uint64) storage.File { + return s.stor.GetFile(num, storage.TypeTable) +} + +func (s *session) getFiles(t storage.FileType) ([]storage.File, error) { + return s.stor.GetFiles(t) +} + +func (s *session) newTemp() storage.File { + num := atomic.AddUint64(&s.stTempFileNum, 1) - 1 + return s.stor.GetFile(num, storage.TypeTemp) +} + +func (s *session) tableFileFromRecord(r atRecord) *tFile { + return newTableFile(s.getTableFile(r.num), r.size, r.imin, r.imax) +} + +// Session state. + +// Get current version. This will incr version ref, must call +// version.release (exactly once) after use. +func (s *session) version() *version { + s.vmu.Lock() + defer s.vmu.Unlock() + s.stVersion.ref++ + return s.stVersion +} + +// Set current version to v. +func (s *session) setVersion(v *version) { + s.vmu.Lock() + v.ref = 1 // Holds by session. + if old := s.stVersion; old != nil { + v.ref++ // Holds by old version. 
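+		// Keeping a reference from the old version pins v and links the
+		// version chain (old.next) until every reader that still holds
+		// the old version has released it.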
+ old.next = v + old.releaseNB() + } + s.stVersion = v + s.vmu.Unlock() +} + +// Get current unused file number. +func (s *session) nextFileNum() uint64 { + return atomic.LoadUint64(&s.stNextFileNum) +} + +// Set current unused file number to num. +func (s *session) setNextFileNum(num uint64) { + atomic.StoreUint64(&s.stNextFileNum, num) +} + +// Mark file number as used. +func (s *session) markFileNum(num uint64) { + nextFileNum := num + 1 + for { + old, x := s.stNextFileNum, nextFileNum + if old > x { + x = old + } + if atomic.CompareAndSwapUint64(&s.stNextFileNum, old, x) { + break + } + } +} + +// Allocate a file number. +func (s *session) allocFileNum() uint64 { + return atomic.AddUint64(&s.stNextFileNum, 1) - 1 +} + +// Reuse given file number. +func (s *session) reuseFileNum(num uint64) { + for { + old, x := s.stNextFileNum, num + if old != x+1 { + x = old + } + if atomic.CompareAndSwapUint64(&s.stNextFileNum, old, x) { + break + } + } +} + +// Manifest related utils. + +// Fill given session record obj with current states; need external +// synchronization. +func (s *session) fillRecord(r *sessionRecord, snapshot bool) { + r.setNextFileNum(s.nextFileNum()) + + if snapshot { + if !r.has(recJournalNum) { + r.setJournalNum(s.stJournalNum) + } + + if !r.has(recSeqNum) { + r.setSeqNum(s.stSeqNum) + } + + for level, ik := range s.stCompPtrs { + if ik != nil { + r.addCompPtr(level, ik) + } + } + + r.setComparer(s.icmp.uName()) + } +} + +// Mark if record has been committed, this will update session state; +// need external synchronization. +func (s *session) recordCommited(r *sessionRecord) { + if r.has(recJournalNum) { + s.stJournalNum = r.journalNum + } + + if r.has(recPrevJournalNum) { + s.stPrevJournalNum = r.prevJournalNum + } + + if r.has(recSeqNum) { + s.stSeqNum = r.seqNum + } + + for _, p := range r.compPtrs { + s.stCompPtrs[p.level] = iKey(p.ikey) + } +} + +// Create a new manifest file; need external synchronization. +func (s *session) newManifest(rec *sessionRecord, v *version) (err error) { + num := s.allocFileNum() + file := s.stor.GetFile(num, storage.TypeManifest) + writer, err := file.Create() + if err != nil { + return + } + jw := journal.NewWriter(writer) + + if v == nil { + v = s.version() + defer v.release() + } + if rec == nil { + rec = &sessionRecord{numLevel: s.o.GetNumLevel()} + } + s.fillRecord(rec, true) + v.fillRecord(rec) + + defer func() { + if err == nil { + s.recordCommited(rec) + if s.manifest != nil { + s.manifest.Close() + } + if s.manifestWriter != nil { + s.manifestWriter.Close() + } + if s.manifestFile != nil { + s.manifestFile.Remove() + } + s.manifestFile = file + s.manifestWriter = writer + s.manifest = jw + } else { + writer.Close() + file.Remove() + s.reuseFileNum(num) + } + }() + + w, err := jw.Next() + if err != nil { + return + } + err = rec.encode(w) + if err != nil { + return + } + err = jw.Flush() + if err != nil { + return + } + err = s.stor.SetManifest(file) + return +} + +// Flush record to disk. 
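+// Appends the given record to the current manifest journal, flushes it and
+// syncs the manifest writer; need external synchronization.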
+func (s *session) flushManifest(rec *sessionRecord) (err error) {
+	s.fillRecord(rec, false)
+	w, err := s.manifest.Next()
+	if err != nil {
+		return
+	}
+	err = rec.encode(w)
+	if err != nil {
+		return
+	}
+	err = s.manifest.Flush()
+	if err != nil {
+		return
+	}
+	err = s.manifestWriter.Sync()
+	if err != nil {
+		return
+	}
+	s.recordCommited(rec)
+	return
+}
diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/file_storage.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/file_storage.go
new file mode 100644
index 0000000000..46cc9d0701
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/file_storage.go
@@ -0,0 +1,534 @@
+// Copyright (c) 2012, Suryandaru Triandana <syndtr@gmail.com>
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+package storage
+
+import (
+	"errors"
+	"fmt"
+	"io/ioutil"
+	"os"
+	"path/filepath"
+	"runtime"
+	"strconv"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/syndtr/goleveldb/leveldb/util"
+)
+
+var errFileOpen = errors.New("leveldb/storage: file still open")
+
+type fileLock interface {
+	release() error
+}
+
+type fileStorageLock struct {
+	fs *fileStorage
+}
+
+func (lock *fileStorageLock) Release() {
+	fs := lock.fs
+	fs.mu.Lock()
+	defer fs.mu.Unlock()
+	if fs.slock == lock {
+		fs.slock = nil
+	}
+	return
+}
+
+// fileStorage is a file-system backed storage.
+type fileStorage struct {
+	path string
+
+	mu    sync.Mutex
+	flock fileLock
+	slock *fileStorageLock
+	logw  *os.File
+	buf   []byte
+	// Opened file counter; if open < 0 means closed.
+	open int
+	day  int
+}
+
+// OpenFile returns a new filesystem-backed storage implementation with the
+// given path. This also holds a file lock, so any subsequent attempt to open
+// the same path will fail.
+//
+// The storage must be closed after use, by calling the Close method.
+func OpenFile(path string) (Storage, error) {
+	if err := os.MkdirAll(path, 0755); err != nil {
+		return nil, err
+	}
+
+	flock, err := newFileLock(filepath.Join(path, "LOCK"))
+	if err != nil {
+		return nil, err
+	}
+
+	defer func() {
+		if err != nil {
+			flock.release()
+		}
+	}()
+
+	rename(filepath.Join(path, "LOG"), filepath.Join(path, "LOG.old"))
+	logw, err := os.OpenFile(filepath.Join(path, "LOG"), os.O_WRONLY|os.O_CREATE, 0644)
+	if err != nil {
+		return nil, err
+	}
+
+	fs := &fileStorage{path: path, flock: flock, logw: logw}
+	runtime.SetFinalizer(fs, (*fileStorage).Close)
+	return fs, nil
+}
+
+func (fs *fileStorage) Lock() (util.Releaser, error) {
+	fs.mu.Lock()
+	defer fs.mu.Unlock()
+	if fs.open < 0 {
+		return nil, ErrClosed
+	}
+	if fs.slock != nil {
+		return nil, ErrLocked
+	}
+	fs.slock = &fileStorageLock{fs: fs}
+	return fs.slock, nil
+}
+
+func itoa(buf []byte, i int, wid int) []byte {
+	var u uint = uint(i)
+	if u == 0 && wid <= 1 {
+		return append(buf, '0')
+	}
+
+	// Assemble decimal in reverse order.
+	var b [32]byte
+	bp := len(b)
+	for ; u > 0 || wid > 0; u /= 10 {
+		bp--
+		wid--
+		b[bp] = byte(u%10) + '0'
+	}
+	return append(buf, b[bp:]...)
+} + +func (fs *fileStorage) printDay(t time.Time) { + if fs.day == t.Day() { + return + } + fs.day = t.Day() + fs.logw.Write([]byte("=============== " + t.Format("Jan 2, 2006 (MST)") + " ===============\n")) +} + +func (fs *fileStorage) doLog(t time.Time, str string) { + fs.printDay(t) + hour, min, sec := t.Clock() + msec := t.Nanosecond() / 1e3 + // time + fs.buf = itoa(fs.buf[:0], hour, 2) + fs.buf = append(fs.buf, ':') + fs.buf = itoa(fs.buf, min, 2) + fs.buf = append(fs.buf, ':') + fs.buf = itoa(fs.buf, sec, 2) + fs.buf = append(fs.buf, '.') + fs.buf = itoa(fs.buf, msec, 6) + fs.buf = append(fs.buf, ' ') + // write + fs.buf = append(fs.buf, []byte(str)...) + fs.buf = append(fs.buf, '\n') + fs.logw.Write(fs.buf) +} + +func (fs *fileStorage) Log(str string) { + t := time.Now() + fs.mu.Lock() + defer fs.mu.Unlock() + if fs.open < 0 { + return + } + fs.doLog(t, str) +} + +func (fs *fileStorage) log(str string) { + fs.doLog(time.Now(), str) +} + +func (fs *fileStorage) GetFile(num uint64, t FileType) File { + return &file{fs: fs, num: num, t: t} +} + +func (fs *fileStorage) GetFiles(t FileType) (ff []File, err error) { + fs.mu.Lock() + defer fs.mu.Unlock() + if fs.open < 0 { + return nil, ErrClosed + } + dir, err := os.Open(fs.path) + if err != nil { + return + } + fnn, err := dir.Readdirnames(0) + // Close the dir first before checking for Readdirnames error. + if err := dir.Close(); err != nil { + fs.log(fmt.Sprintf("close dir: %v", err)) + } + if err != nil { + return + } + f := &file{fs: fs} + for _, fn := range fnn { + if f.parse(fn) && (f.t&t) != 0 { + ff = append(ff, f) + f = &file{fs: fs} + } + } + return +} + +func (fs *fileStorage) GetManifest() (f File, err error) { + fs.mu.Lock() + defer fs.mu.Unlock() + if fs.open < 0 { + return nil, ErrClosed + } + dir, err := os.Open(fs.path) + if err != nil { + return + } + fnn, err := dir.Readdirnames(0) + // Close the dir first before checking for Readdirnames error. + if err := dir.Close(); err != nil { + fs.log(fmt.Sprintf("close dir: %v", err)) + } + if err != nil { + return + } + // Find latest CURRENT file. + var rem []string + var pend bool + var cerr error + for _, fn := range fnn { + if strings.HasPrefix(fn, "CURRENT") { + pend1 := len(fn) > 7 + // Make sure it is valid name for a CURRENT file, otherwise skip it. + if pend1 { + if fn[7] != '.' || len(fn) < 9 { + fs.log(fmt.Sprintf("skipping %s: invalid file name", fn)) + continue + } + if _, e1 := strconv.ParseUint(fn[8:], 10, 0); e1 != nil { + fs.log(fmt.Sprintf("skipping %s: invalid file num: %v", fn, e1)) + continue + } + } + path := filepath.Join(fs.path, fn) + r, e1 := os.OpenFile(path, os.O_RDONLY, 0) + if e1 != nil { + return nil, e1 + } + b, e1 := ioutil.ReadAll(r) + if e1 != nil { + r.Close() + return nil, e1 + } + f1 := &file{fs: fs} + if len(b) < 1 || b[len(b)-1] != '\n' || !f1.parse(string(b[:len(b)-1])) { + fs.log(fmt.Sprintf("skipping %s: corrupted or incomplete", fn)) + if pend1 { + rem = append(rem, fn) + } + if !pend1 || cerr == nil { + cerr = fmt.Errorf("leveldb/storage: corrupted or incomplete %s file", fn) + } + } else if f != nil && f1.Num() < f.Num() { + fs.log(fmt.Sprintf("skipping %s: obsolete", fn)) + if pend1 { + rem = append(rem, fn) + } + } else { + f = f1 + pend = pend1 + } + if err := r.Close(); err != nil { + fs.log(fmt.Sprintf("close %s: %v", fn, err)) + } + } + } + // Don't remove any files if there is no valid CURRENT file. 
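+	// (The pending CURRENT files collected in rem are only removed further
+	// below, after a valid CURRENT file has been found.)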
+ if f == nil { + if cerr != nil { + err = cerr + } else { + err = os.ErrNotExist + } + return + } + // Rename pending CURRENT file to an effective CURRENT. + if pend { + path := fmt.Sprintf("%s.%d", filepath.Join(fs.path, "CURRENT"), f.Num()) + if err := rename(path, filepath.Join(fs.path, "CURRENT")); err != nil { + fs.log(fmt.Sprintf("CURRENT.%d -> CURRENT: %v", f.Num(), err)) + } + } + // Remove obsolete or incomplete pending CURRENT files. + for _, fn := range rem { + path := filepath.Join(fs.path, fn) + if err := os.Remove(path); err != nil { + fs.log(fmt.Sprintf("remove %s: %v", fn, err)) + } + } + return +} + +func (fs *fileStorage) SetManifest(f File) (err error) { + fs.mu.Lock() + defer fs.mu.Unlock() + if fs.open < 0 { + return ErrClosed + } + f2, ok := f.(*file) + if !ok || f2.t != TypeManifest { + return ErrInvalidFile + } + defer func() { + if err != nil { + fs.log(fmt.Sprintf("CURRENT: %v", err)) + } + }() + path := fmt.Sprintf("%s.%d", filepath.Join(fs.path, "CURRENT"), f2.Num()) + w, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + return err + } + _, err = fmt.Fprintln(w, f2.name()) + // Close the file first. + if err := w.Close(); err != nil { + fs.log(fmt.Sprintf("close CURRENT.%d: %v", f2.num, err)) + } + if err != nil { + return err + } + return rename(path, filepath.Join(fs.path, "CURRENT")) +} + +func (fs *fileStorage) Close() error { + fs.mu.Lock() + defer fs.mu.Unlock() + if fs.open < 0 { + return ErrClosed + } + // Clear the finalizer. + runtime.SetFinalizer(fs, nil) + + if fs.open > 0 { + fs.log(fmt.Sprintf("refuse to close, %d files still open", fs.open)) + return fmt.Errorf("leveldb/storage: cannot close, %d files still open", fs.open) + } + fs.open = -1 + e1 := fs.logw.Close() + err := fs.flock.release() + if err == nil { + err = e1 + } + return err +} + +type fileWrap struct { + *os.File + f *file +} + +func (fw fileWrap) Sync() error { + if err := fw.File.Sync(); err != nil { + return err + } + if fw.f.Type() == TypeManifest { + // Also sync parent directory if file type is manifest. + // See: https://code.google.com/p/leveldb/issues/detail?id=190. 
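+		// (Syncing the file alone does not make its directory entry durable;
+		// the parent directory must also be synced for a newly created
+		// manifest to survive a crash.)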
+ if err := syncDir(fw.f.fs.path); err != nil { + return err + } + } + return nil +} + +func (fw fileWrap) Close() error { + f := fw.f + f.fs.mu.Lock() + defer f.fs.mu.Unlock() + if !f.open { + return ErrClosed + } + f.open = false + f.fs.open-- + err := fw.File.Close() + if err != nil { + f.fs.log(fmt.Sprintf("close %s.%d: %v", f.Type(), f.Num(), err)) + } + return err +} + +type file struct { + fs *fileStorage + num uint64 + t FileType + open bool +} + +func (f *file) Open() (Reader, error) { + f.fs.mu.Lock() + defer f.fs.mu.Unlock() + if f.fs.open < 0 { + return nil, ErrClosed + } + if f.open { + return nil, errFileOpen + } + of, err := os.OpenFile(f.path(), os.O_RDONLY, 0) + if err != nil { + if f.hasOldName() && os.IsNotExist(err) { + of, err = os.OpenFile(f.oldPath(), os.O_RDONLY, 0) + if err == nil { + goto ok + } + } + return nil, err + } +ok: + f.open = true + f.fs.open++ + return fileWrap{of, f}, nil +} + +func (f *file) Create() (Writer, error) { + f.fs.mu.Lock() + defer f.fs.mu.Unlock() + if f.fs.open < 0 { + return nil, ErrClosed + } + if f.open { + return nil, errFileOpen + } + of, err := os.OpenFile(f.path(), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + return nil, err + } + f.open = true + f.fs.open++ + return fileWrap{of, f}, nil +} + +func (f *file) Replace(newfile File) error { + f.fs.mu.Lock() + defer f.fs.mu.Unlock() + if f.fs.open < 0 { + return ErrClosed + } + newfile2, ok := newfile.(*file) + if !ok { + return ErrInvalidFile + } + if f.open || newfile2.open { + return errFileOpen + } + return rename(newfile2.path(), f.path()) +} + +func (f *file) Type() FileType { + return f.t +} + +func (f *file) Num() uint64 { + return f.num +} + +func (f *file) Remove() error { + f.fs.mu.Lock() + defer f.fs.mu.Unlock() + if f.fs.open < 0 { + return ErrClosed + } + if f.open { + return errFileOpen + } + err := os.Remove(f.path()) + if err != nil { + f.fs.log(fmt.Sprintf("remove %s.%d: %v", f.Type(), f.Num(), err)) + } + // Also try remove file with old name, just in case. 
+ if f.hasOldName() { + if e1 := os.Remove(f.oldPath()); !os.IsNotExist(e1) { + f.fs.log(fmt.Sprintf("remove %s.%d: %v (old name)", f.Type(), f.Num(), err)) + err = e1 + } + } + return err +} + +func (f *file) hasOldName() bool { + return f.t == TypeTable +} + +func (f *file) oldName() string { + switch f.t { + case TypeTable: + return fmt.Sprintf("%06d.sst", f.num) + } + return f.name() +} + +func (f *file) oldPath() string { + return filepath.Join(f.fs.path, f.oldName()) +} + +func (f *file) name() string { + switch f.t { + case TypeManifest: + return fmt.Sprintf("MANIFEST-%06d", f.num) + case TypeJournal: + return fmt.Sprintf("%06d.log", f.num) + case TypeTable: + return fmt.Sprintf("%06d.ldb", f.num) + case TypeTemp: + return fmt.Sprintf("%06d.tmp", f.num) + default: + panic("invalid file type") + } +} + +func (f *file) path() string { + return filepath.Join(f.fs.path, f.name()) +} + +func (f *file) parse(name string) bool { + var num uint64 + var tail string + _, err := fmt.Sscanf(name, "%d.%s", &num, &tail) + if err == nil { + switch tail { + case "log": + f.t = TypeJournal + case "ldb", "sst": + f.t = TypeTable + case "tmp": + f.t = TypeTemp + default: + return false + } + f.num = num + return true + } + n, _ := fmt.Sscanf(name, "MANIFEST-%d%s", &num, &tail) + if n == 1 { + f.t = TypeManifest + f.num = num + return true + } + + return false +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/file_storage_plan9.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/file_storage_plan9.go new file mode 100644 index 0000000000..42940d769f --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/file_storage_plan9.go @@ -0,0 +1,52 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package storage + +import ( + "os" + "path/filepath" +) + +type plan9FileLock struct { + f *os.File +} + +func (fl *plan9FileLock) release() error { + return fl.f.Close() +} + +func newFileLock(path string) (fl fileLock, err error) { + f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, os.ModeExclusive|0644) + if err != nil { + return + } + fl = &plan9FileLock{f: f} + return +} + +func rename(oldpath, newpath string) error { + if _, err := os.Stat(newpath); err == nil { + if err := os.Remove(newpath); err != nil { + return err + } + } + + _, fname := filepath.Split(newpath) + return os.Rename(oldpath, fname) +} + +func syncDir(name string) error { + f, err := os.Open(name) + if err != nil { + return err + } + defer f.Close() + if err := f.Sync(); err != nil { + return err + } + return nil +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/file_storage_solaris.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/file_storage_solaris.go new file mode 100644 index 0000000000..102031bfd5 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/file_storage_solaris.go @@ -0,0 +1,68 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+ +// +build solaris + +package storage + +import ( + "os" + "syscall" +) + +type unixFileLock struct { + f *os.File +} + +func (fl *unixFileLock) release() error { + if err := setFileLock(fl.f, false); err != nil { + return err + } + return fl.f.Close() +} + +func newFileLock(path string) (fl fileLock, err error) { + f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0644) + if err != nil { + return + } + err = setFileLock(f, true) + if err != nil { + f.Close() + return + } + fl = &unixFileLock{f: f} + return +} + +func setFileLock(f *os.File, lock bool) error { + flock := syscall.Flock_t{ + Type: syscall.F_UNLCK, + Start: 0, + Len: 0, + Whence: 1, + } + if lock { + flock.Type = syscall.F_WRLCK + } + return syscall.FcntlFlock(f.Fd(), syscall.F_SETLK, &flock) +} + +func rename(oldpath, newpath string) error { + return os.Rename(oldpath, newpath) +} + +func syncDir(name string) error { + f, err := os.Open(name) + if err != nil { + return err + } + defer f.Close() + if err := f.Sync(); err != nil { + return err + } + return nil +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/file_storage_test.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/file_storage_test.go new file mode 100644 index 0000000000..92abcbb7d0 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/file_storage_test.go @@ -0,0 +1,142 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package storage + +import ( + "fmt" + "os" + "path/filepath" + "testing" +) + +var cases = []struct { + oldName []string + name string + ftype FileType + num uint64 +}{ + {nil, "000100.log", TypeJournal, 100}, + {nil, "000000.log", TypeJournal, 0}, + {[]string{"000000.sst"}, "000000.ldb", TypeTable, 0}, + {nil, "MANIFEST-000002", TypeManifest, 2}, + {nil, "MANIFEST-000007", TypeManifest, 7}, + {nil, "18446744073709551615.log", TypeJournal, 18446744073709551615}, + {nil, "000100.tmp", TypeTemp, 100}, +} + +var invalidCases = []string{ + "", + "foo", + "foo-dx-100.log", + ".log", + "", + "manifest", + "CURREN", + "CURRENTX", + "MANIFES", + "MANIFEST", + "MANIFEST-", + "XMANIFEST-3", + "MANIFEST-3x", + "LOC", + "LOCKx", + "LO", + "LOGx", + "18446744073709551616.log", + "184467440737095516150.log", + "100", + "100.", + "100.lop", +} + +func TestFileStorage_CreateFileName(t *testing.T) { + for _, c := range cases { + f := &file{num: c.num, t: c.ftype} + if f.name() != c.name { + t.Errorf("invalid filename got '%s', want '%s'", f.name(), c.name) + } + } +} + +func TestFileStorage_ParseFileName(t *testing.T) { + for _, c := range cases { + for _, name := range append([]string{c.name}, c.oldName...) 
{ + f := new(file) + if !f.parse(name) { + t.Errorf("cannot parse filename '%s'", name) + continue + } + if f.Type() != c.ftype { + t.Errorf("filename '%s' invalid type got '%d', want '%d'", name, f.Type(), c.ftype) + } + if f.Num() != c.num { + t.Errorf("filename '%s' invalid number got '%d', want '%d'", name, f.Num(), c.num) + } + } + } +} + +func TestFileStorage_InvalidFileName(t *testing.T) { + for _, name := range invalidCases { + f := new(file) + if f.parse(name) { + t.Errorf("filename '%s' should be invalid", name) + } + } +} + +func TestFileStorage_Locking(t *testing.T) { + path := filepath.Join(os.TempDir(), fmt.Sprintf("goleveldbtestfd-%d", os.Getuid())) + + _, err := os.Stat(path) + if err == nil { + err = os.RemoveAll(path) + if err != nil { + t.Fatal("RemoveAll: got error: ", err) + } + } + + p1, err := OpenFile(path) + if err != nil { + t.Fatal("OpenFile(1): got error: ", err) + } + + defer os.RemoveAll(path) + + p2, err := OpenFile(path) + if err != nil { + t.Logf("OpenFile(2): got error: %s (expected)", err) + } else { + p2.Close() + p1.Close() + t.Fatal("OpenFile(2): expect error") + } + + p1.Close() + + p3, err := OpenFile(path) + if err != nil { + t.Fatal("OpenFile(3): got error: ", err) + } + defer p3.Close() + + l, err := p3.Lock() + if err != nil { + t.Fatal("storage lock failed(1): ", err) + } + _, err = p3.Lock() + if err == nil { + t.Fatal("expect error for second storage lock attempt") + } else { + t.Logf("storage lock got error: %s (expected)", err) + } + l.Release() + _, err = p3.Lock() + if err != nil { + t.Fatal("storage lock failed(2): ", err) + } +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/file_storage_unix.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/file_storage_unix.go new file mode 100644 index 0000000000..d0a604b7ab --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/file_storage_unix.go @@ -0,0 +1,63 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+ +// +build darwin dragonfly freebsd linux netbsd openbsd + +package storage + +import ( + "os" + "syscall" +) + +type unixFileLock struct { + f *os.File +} + +func (fl *unixFileLock) release() error { + if err := setFileLock(fl.f, false); err != nil { + return err + } + return fl.f.Close() +} + +func newFileLock(path string) (fl fileLock, err error) { + f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0644) + if err != nil { + return + } + err = setFileLock(f, true) + if err != nil { + f.Close() + return + } + fl = &unixFileLock{f: f} + return +} + +func setFileLock(f *os.File, lock bool) error { + how := syscall.LOCK_UN + if lock { + how = syscall.LOCK_EX + } + return syscall.Flock(int(f.Fd()), how|syscall.LOCK_NB) +} + +func rename(oldpath, newpath string) error { + return os.Rename(oldpath, newpath) +} + +func syncDir(name string) error { + f, err := os.Open(name) + if err != nil { + return err + } + defer f.Close() + if err := f.Sync(); err != nil { + return err + } + return nil +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/file_storage_windows.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/file_storage_windows.go new file mode 100644 index 0000000000..50c3c454e7 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/file_storage_windows.go @@ -0,0 +1,69 @@ +// Copyright (c) 2013, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package storage + +import ( + "syscall" + "unsafe" +) + +var ( + modkernel32 = syscall.NewLazyDLL("kernel32.dll") + + procMoveFileExW = modkernel32.NewProc("MoveFileExW") +) + +const ( + _MOVEFILE_REPLACE_EXISTING = 1 +) + +type windowsFileLock struct { + fd syscall.Handle +} + +func (fl *windowsFileLock) release() error { + return syscall.Close(fl.fd) +} + +func newFileLock(path string) (fl fileLock, err error) { + pathp, err := syscall.UTF16PtrFromString(path) + if err != nil { + return + } + fd, err := syscall.CreateFile(pathp, syscall.GENERIC_READ|syscall.GENERIC_WRITE, 0, nil, syscall.CREATE_ALWAYS, syscall.FILE_ATTRIBUTE_NORMAL, 0) + if err != nil { + return + } + fl = &windowsFileLock{fd: fd} + return +} + +func moveFileEx(from *uint16, to *uint16, flags uint32) error { + r1, _, e1 := syscall.Syscall(procMoveFileExW.Addr(), 3, uintptr(unsafe.Pointer(from)), uintptr(unsafe.Pointer(to)), uintptr(flags)) + if r1 == 0 { + if e1 != 0 { + return error(e1) + } else { + return syscall.EINVAL + } + } + return nil +} + +func rename(oldpath, newpath string) error { + from, err := syscall.UTF16PtrFromString(oldpath) + if err != nil { + return err + } + to, err := syscall.UTF16PtrFromString(newpath) + if err != nil { + return err + } + return moveFileEx(from, to, _MOVEFILE_REPLACE_EXISTING) +} + +func syncDir(name string) error { return nil } diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/mem_storage.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/mem_storage.go new file mode 100644 index 0000000000..fc1c8165df --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/mem_storage.go @@ -0,0 +1,203 @@ +// Copyright (c) 2013, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+
+package storage
+
+import (
+	"bytes"
+	"os"
+	"sync"
+
+	"github.com/syndtr/goleveldb/leveldb/util"
+)
+
+const typeShift = 3
+
+type memStorageLock struct {
+	ms *memStorage
+}
+
+func (lock *memStorageLock) Release() {
+	ms := lock.ms
+	ms.mu.Lock()
+	defer ms.mu.Unlock()
+	if ms.slock == lock {
+		ms.slock = nil
+	}
+	return
+}
+
+// memStorage is a memory-backed storage.
+type memStorage struct {
+	mu       sync.Mutex
+	slock    *memStorageLock
+	files    map[uint64]*memFile
+	manifest *memFilePtr
+}
+
+// NewMemStorage returns a new memory-backed storage implementation.
+func NewMemStorage() Storage {
+	return &memStorage{
+		files: make(map[uint64]*memFile),
+	}
+}
+
+func (ms *memStorage) Lock() (util.Releaser, error) {
+	ms.mu.Lock()
+	defer ms.mu.Unlock()
+	if ms.slock != nil {
+		return nil, ErrLocked
+	}
+	ms.slock = &memStorageLock{ms: ms}
+	return ms.slock, nil
+}
+
+func (*memStorage) Log(str string) {}
+
+func (ms *memStorage) GetFile(num uint64, t FileType) File {
+	return &memFilePtr{ms: ms, num: num, t: t}
+}
+
+func (ms *memStorage) GetFiles(t FileType) ([]File, error) {
+	ms.mu.Lock()
+	var ff []File
+	for x := range ms.files {
+		num, mt := x>>typeShift, FileType(x)&TypeAll
+		if mt&t == 0 {
+			continue
+		}
+		ff = append(ff, &memFilePtr{ms: ms, num: num, t: mt})
+	}
+	ms.mu.Unlock()
+	return ff, nil
+}
+
+func (ms *memStorage) GetManifest() (File, error) {
+	ms.mu.Lock()
+	defer ms.mu.Unlock()
+	if ms.manifest == nil {
+		return nil, os.ErrNotExist
+	}
+	return ms.manifest, nil
+}
+
+func (ms *memStorage) SetManifest(f File) error {
+	fm, ok := f.(*memFilePtr)
+	if !ok || fm.t != TypeManifest {
+		return ErrInvalidFile
+	}
+	ms.mu.Lock()
+	ms.manifest = fm
+	ms.mu.Unlock()
+	return nil
+}
+
+func (*memStorage) Close() error { return nil }
+
+type memReader struct {
+	*bytes.Reader
+	m *memFile
+}
+
+func (mr *memReader) Close() error {
+	return mr.m.Close()
+}
+
+type memFile struct {
+	bytes.Buffer
+	ms   *memStorage
+	open bool
+}
+
+func (*memFile) Sync() error { return nil }
+func (m *memFile) Close() error {
+	m.ms.mu.Lock()
+	m.open = false
+	m.ms.mu.Unlock()
+	return nil
+}
+
+type memFilePtr struct {
+	ms  *memStorage
+	num uint64
+	t   FileType
+}
+
+func (p *memFilePtr) x() uint64 {
+	return p.Num()<<typeShift | uint64(p.t)
+}
+
+func (p *memFilePtr) Open() (Reader, error) {
+	ms := p.ms
+	ms.mu.Lock()
+	defer ms.mu.Unlock()
+	if m, exist := ms.files[p.x()]; exist {
+		if m.open {
+			return nil, errFileOpen
+		}
+		m.open = true
+		return &memReader{Reader: bytes.NewReader(m.Bytes()), m: m}, nil
+	}
+	return nil, os.ErrNotExist
+}
+
+func (p *memFilePtr) Create() (Writer, error) {
+	ms := p.ms
+	ms.mu.Lock()
+	defer ms.mu.Unlock()
+	m, exist := ms.files[p.x()]
+	if exist {
+		if m.open {
+			return nil, errFileOpen
+		}
+		m.Reset()
+	} else {
+		m = &memFile{ms: ms}
+		ms.files[p.x()] = m
+	}
+	m.open = true
+	return m, nil
+}
+
+func (p *memFilePtr) Replace(newfile File) error {
+	p2, ok := newfile.(*memFilePtr)
+	if !ok {
+		return ErrInvalidFile
+	}
+	ms := p.ms
+	ms.mu.Lock()
+	defer ms.mu.Unlock()
+	m1, exist := ms.files[p2.x()]
+	if !exist {
+		return os.ErrNotExist
+	}
+	m0, exist := ms.files[p.x()]
+	if (exist && m0.open) || m1.open {
+		return errFileOpen
+	}
+	delete(ms.files, p2.x())
+	ms.files[p.x()] = m1
+	return nil
+}
+
+func (p *memFilePtr) Type() FileType {
+	return p.t
+}
+
+func (p *memFilePtr) Num() uint64 {
+	return p.num
+}
+
+func (p *memFilePtr) Remove() error {
+	ms := p.ms
+	ms.mu.Lock()
+	defer ms.mu.Unlock()
+	if _, exist := ms.files[p.x()]; exist {
+		delete(ms.files, p.x())
+		return nil
+	}
+	return os.ErrNotExist
+}
diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/mem_storage_test.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/mem_storage_test.go
new file mode 100644
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/mem_storage_test.go
@@ -0,0 +1,66 @@
+// Copyright (c) 2013, Suryandaru Triandana <syndtr@gmail.com>
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+package storage
+
+import (
+	"bytes"
+	"testing"
+)
+
+func TestMemStorage(t *testing.T) {
+	m := NewMemStorage()
+
+	l, err := m.Lock()
+	if err != nil {
+		t.Fatal("storage lock failed(1): ", err)
+	}
+	_, err = m.Lock()
+	if err == nil {
+		t.Fatal("expect error for second storage lock attempt")
+	} else {
+		t.Logf("storage lock got error: %s (expected)", err)
+	}
+	l.Release()
+	_, err = m.Lock()
+	if err != nil {
+		t.Fatal("storage lock failed(2): ", err)
+	}
+
+	f := m.GetFile(1, TypeTable)
+	if f.Num() != 1 && f.Type() != TypeTable {
+		t.Fatal("invalid file number and type")
+	}
+	w, _ := f.Create()
+	w.Write([]byte("abc"))
+	w.Close()
+	if ff, _ := m.GetFiles(TypeAll); len(ff) != 1 {
+		t.Fatal("invalid GetFiles len")
+	}
+	buf := new(bytes.Buffer)
+	r, err := f.Open()
+	if err != nil {
+		t.Fatal("Open: got error: ", err)
+	}
+	buf.ReadFrom(r)
+	r.Close()
+	if got := buf.String(); got != "abc" {
+		t.Fatalf("Read: invalid value, want=abc got=%s", got)
+	}
+	if _, err := f.Open(); err != nil {
+		t.Fatal("Open: got error: ", err)
+	}
+	if _, err := m.GetFile(1, TypeTable).Open(); err == nil {
+		t.Fatal("expecting error")
+	}
+	f.Remove()
+	if ff, _ := m.GetFiles(TypeAll); len(ff) != 0 {
+		t.Fatal("invalid GetFiles len", len(ff))
+	}
+	if _, err := f.Open(); err == nil {
+		t.Fatal("expecting error")
+	}
+}
diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/storage.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/storage.go
new file mode 100644
index 0000000000..85dd70b06f
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage/storage.go
@@ -0,0 +1,157 @@
+// Copyright (c) 2012, Suryandaru Triandana <syndtr@gmail.com>
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+// Package storage provides storage abstraction for LevelDB.
+package storage
+
+import (
+	"errors"
+	"fmt"
+	"io"
+
+	"github.com/syndtr/goleveldb/leveldb/util"
+)
+
+type FileType uint32
+
+const (
+	TypeManifest FileType = 1 << iota
+	TypeJournal
+	TypeTable
+	TypeTemp
+
+	TypeAll = TypeManifest | TypeJournal | TypeTable | TypeTemp
+)
+
+func (t FileType) String() string {
+	switch t {
+	case TypeManifest:
+		return "manifest"
+	case TypeJournal:
+		return "journal"
+	case TypeTable:
+		return "table"
+	case TypeTemp:
+		return "temp"
+	}
+	return fmt.Sprintf("<unknown:%d>", t)
+}
+
+var (
+	ErrInvalidFile = errors.New("leveldb/storage: invalid file for argument")
+	ErrLocked      = errors.New("leveldb/storage: already locked")
+	ErrClosed      = errors.New("leveldb/storage: closed")
+)
+
+// Syncer is the interface that wraps basic Sync method.
+type Syncer interface {
+	// Sync commits the current contents of the file to stable storage.
+	Sync() error
+}
+
+// Reader is the interface that groups the basic Read, Seek, ReadAt and Close
+// methods.
+type Reader interface {
+	io.ReadSeeker
+	io.ReaderAt
+	io.Closer
+}
+
+// Writer is the interface that groups the basic Write, Sync and Close
+// methods.
+type Writer interface {
+	io.WriteCloser
+	Syncer
+}
+
+// File represents a file in the storage. A file instance must be
+// goroutine-safe.
+type File interface {
+	// Open opens the file for read. Returns os.ErrNotExist error
+	// if the file does not exist.
+	// Returns ErrClosed if the underlying storage is closed.
+	Open() (r Reader, err error)
+
+	// Create creates the file for writing. Truncates the file if it
+	// already exists.
+	// Returns ErrClosed if the underlying storage is closed.
+	Create() (w Writer, err error)
+
+	// Replace replaces file with newfile.
+	// Returns ErrClosed if the underlying storage is closed.
+	Replace(newfile File) error
+
+	// Type returns the file type.
+	Type() FileType
+
+	// Num returns the file number.
+	Num() uint64
+
+	// Remove removes the file.
+	// Returns ErrClosed if the underlying storage is closed.
+	Remove() error
+}
+
+// Storage is the storage backend. A storage instance must be goroutine-safe.
+type Storage interface {
+	// Lock locks the storage. Any subsequent attempt to call Lock will fail
+	// until the last lock is released.
+	// After use the caller should call the Release method.
+	Lock() (l util.Releaser, err error)
+
+	// Log logs a string. This is used for logging. An implementation
+	// may write to a file, stdout or simply do nothing.
+	Log(str string)
+
+	// GetFile returns a file for the given number and type. GetFile will
+	// never return nil, even if the underlying storage is closed.
+	GetFile(num uint64, t FileType) File
+
+	// GetFiles returns a slice of files that match the given file types.
+	// The file types may be OR'ed together.
+	GetFiles(t FileType) ([]File, error)
+
+	// GetManifest returns a manifest file. Returns os.ErrNotExist if the
+	// manifest file does not exist.
+	GetManifest() (File, error)
+
+	// SetManifest sets the given file as the manifest file. The given file
+	// must be a manifest file type or an error will be returned.
+	SetManifest(f File) error
+
+	// Close closes the storage. It is valid to call Close multiple times.
+	// Other methods should not be called after the storage has been closed.
+	Close() error
+}
+
+// FileInfo wraps basic file info.
+type FileInfo struct {
+	Type FileType
+	Num  uint64
+}
+
+func (fi FileInfo) String() string {
+	switch fi.Type {
+	case TypeManifest:
+		return fmt.Sprintf("MANIFEST-%06d", fi.Num)
+	case TypeJournal:
+		return fmt.Sprintf("%06d.log", fi.Num)
+	case TypeTable:
+		return fmt.Sprintf("%06d.ldb", fi.Num)
+	case TypeTemp:
+		return fmt.Sprintf("%06d.tmp", fi.Num)
+	default:
+		return fmt.Sprintf("%#x-%d", fi.Type, fi.Num)
+	}
+}
+
+// NewFileInfo creates a new FileInfo from the given File. It returns nil
+// if File is nil.
+func NewFileInfo(f File) *FileInfo {
+	if f == nil {
+		return nil
+	}
+	return &FileInfo{f.Type(), f.Num()}
+}
diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage_test.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage_test.go
new file mode 100644
index 0000000000..dc1f1fb54b
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/storage_test.go
@@ -0,0 +1,539 @@
+// Copyright (c) 2012, Suryandaru Triandana <syndtr@gmail.com>
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+ +package leveldb + +import ( + "errors" + "fmt" + "io" + "io/ioutil" + "math/rand" + "os" + "path/filepath" + "sync" + "testing" + + "github.com/syndtr/goleveldb/leveldb/storage" + "github.com/syndtr/goleveldb/leveldb/util" +) + +const typeShift = 4 + +var ( + tsErrInvalidFile = errors.New("leveldb.testStorage: invalid file for argument") + tsErrFileOpen = errors.New("leveldb.testStorage: file still open") +) + +var ( + tsFSEnv = os.Getenv("GOLEVELDB_USEFS") + tsTempdir = os.Getenv("GOLEVELDB_TEMPDIR") + tsKeepFS = tsFSEnv == "2" + tsFS = tsKeepFS || tsFSEnv == "" || tsFSEnv == "1" + tsMU = &sync.Mutex{} + tsNum = 0 +) + +type tsOp uint + +const ( + tsOpOpen tsOp = iota + tsOpCreate + tsOpRead + tsOpReadAt + tsOpWrite + tsOpSync + + tsOpNum +) + +type tsLock struct { + ts *testStorage + r util.Releaser +} + +func (l tsLock) Release() { + l.r.Release() + l.ts.t.Log("I: storage lock released") +} + +type tsReader struct { + tf tsFile + storage.Reader +} + +func (tr tsReader) Read(b []byte) (n int, err error) { + ts := tr.tf.ts + ts.countRead(tr.tf.Type()) + if tr.tf.shouldErrLocked(tsOpRead) { + return 0, errors.New("leveldb.testStorage: emulated read error") + } + n, err = tr.Reader.Read(b) + if err != nil && err != io.EOF { + ts.t.Errorf("E: read error, num=%d type=%v n=%d: %v", tr.tf.Num(), tr.tf.Type(), n, err) + } + return +} + +func (tr tsReader) ReadAt(b []byte, off int64) (n int, err error) { + ts := tr.tf.ts + ts.countRead(tr.tf.Type()) + if tr.tf.shouldErrLocked(tsOpReadAt) { + return 0, errors.New("leveldb.testStorage: emulated readAt error") + } + n, err = tr.Reader.ReadAt(b, off) + if err != nil && err != io.EOF { + ts.t.Errorf("E: readAt error, num=%d type=%v off=%d n=%d: %v", tr.tf.Num(), tr.tf.Type(), off, n, err) + } + return +} + +func (tr tsReader) Close() (err error) { + err = tr.Reader.Close() + tr.tf.close("reader", err) + return +} + +type tsWriter struct { + tf tsFile + storage.Writer +} + +func (tw tsWriter) Write(b []byte) (n int, err error) { + if tw.tf.shouldErrLocked(tsOpWrite) { + return 0, errors.New("leveldb.testStorage: emulated write error") + } + n, err = tw.Writer.Write(b) + if err != nil { + tw.tf.ts.t.Errorf("E: write error, num=%d type=%v n=%d: %v", tw.tf.Num(), tw.tf.Type(), n, err) + } + return +} + +func (tw tsWriter) Sync() (err error) { + ts := tw.tf.ts + ts.mu.Lock() + for ts.emuDelaySync&tw.tf.Type() != 0 { + ts.cond.Wait() + } + ts.mu.Unlock() + if tw.tf.shouldErrLocked(tsOpSync) { + return errors.New("leveldb.testStorage: emulated sync error") + } + err = tw.Writer.Sync() + if err != nil { + tw.tf.ts.t.Errorf("E: sync error, num=%d type=%v: %v", tw.tf.Num(), tw.tf.Type(), err) + } + return +} + +func (tw tsWriter) Close() (err error) { + err = tw.Writer.Close() + tw.tf.close("writer", err) + return +} + +type tsFile struct { + ts *testStorage + storage.File +} + +func (tf tsFile) x() uint64 { + return tf.Num()<>typeShift, storage.FileType(x)&storage.TypeAll + ts.t.Errorf("E: * num=%d type=%v writer=%v", num, tt, writer) + } + } + ts.mu.Unlock() +} + +func newTestStorage(t *testing.T) *testStorage { + var stor storage.Storage + var closeFn func() error + if tsFS { + for { + tsMU.Lock() + num := tsNum + tsNum++ + tsMU.Unlock() + tempdir := tsTempdir + if tempdir == "" { + tempdir = os.TempDir() + } + path := filepath.Join(tempdir, fmt.Sprintf("goleveldb-test%d0%d0%d", os.Getuid(), os.Getpid(), num)) + if _, err := os.Stat(path); err != nil { + stor, err = storage.OpenFile(path) + if err != nil { + t.Fatalf("F: cannot create storage: %v", err) + 
} + t.Logf("I: storage created: %s", path) + closeFn = func() error { + for _, name := range []string{"LOG.old", "LOG"} { + f, err := os.Open(filepath.Join(path, name)) + if err != nil { + continue + } + if log, err := ioutil.ReadAll(f); err != nil { + t.Logf("---------------------- %s ----------------------", name) + t.Logf("cannot read log: %v", err) + t.Logf("---------------------- %s ----------------------", name) + } else if len(log) > 0 { + t.Logf("---------------------- %s ----------------------\n%s", name, string(log)) + t.Logf("---------------------- %s ----------------------", name) + } + f.Close() + } + if t.Failed() { + t.Logf("testing failed, test DB preserved at %s", path) + return nil + } + if tsKeepFS { + return nil + } + return os.RemoveAll(path) + } + + break + } + } + } else { + stor = storage.NewMemStorage() + } + ts := &testStorage{ + t: t, + Storage: stor, + closeFn: closeFn, + opens: make(map[uint64]bool), + emuErrOnceMap: make(map[uint64]uint), + emuRandErrProb: 0x999, + emuRandRand: rand.New(rand.NewSource(0xfacedead)), + } + ts.cond.L = &ts.mu + return ts +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/table.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/table.go new file mode 100644 index 0000000000..3e8df6af51 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/table.go @@ -0,0 +1,521 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package leveldb + +import ( + "fmt" + "sort" + "sync/atomic" + + "github.com/syndtr/goleveldb/leveldb/cache" + "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/opt" + "github.com/syndtr/goleveldb/leveldb/storage" + "github.com/syndtr/goleveldb/leveldb/table" + "github.com/syndtr/goleveldb/leveldb/util" +) + +// tFile holds basic information about a table. +type tFile struct { + file storage.File + seekLeft int32 + size uint64 + imin, imax iKey +} + +// Returns true if given key is after largest key of this table. +func (t *tFile) after(icmp *iComparer, ukey []byte) bool { + return ukey != nil && icmp.uCompare(ukey, t.imax.ukey()) > 0 +} + +// Returns true if given key is before smallest key of this table. +func (t *tFile) before(icmp *iComparer, ukey []byte) bool { + return ukey != nil && icmp.uCompare(ukey, t.imin.ukey()) < 0 +} + +// Returns true if given key range overlaps with this table key range. +func (t *tFile) overlaps(icmp *iComparer, umin, umax []byte) bool { + return !t.after(icmp, umin) && !t.before(icmp, umax) +} + +// Cosumes one seek and return current seeks left. +func (t *tFile) consumeSeek() int32 { + return atomic.AddInt32(&t.seekLeft, -1) +} + +// Creates new tFile. +func newTableFile(file storage.File, size uint64, imin, imax iKey) *tFile { + f := &tFile{ + file: file, + size: size, + imin: imin, + imax: imax, + } + + // We arrange to automatically compact this file after + // a certain number of seeks. Let's assume: + // (1) One seek costs 10ms + // (2) Writing or reading 1MB costs 10ms (100MB/s) + // (3) A compaction of 1MB does 25MB of IO: + // 1MB read from this level + // 10-12MB read from next level (boundaries may be misaligned) + // 10-12MB written to next level + // This implies that 25 seeks cost the same as the compaction + // of 1MB of data. I.e., one seek costs approximately the + // same as the compaction of 40KB of data. 
We are a little + // conservative and allow approximately one seek for every 16KB + // of data before triggering a compaction. + f.seekLeft = int32(size / 16384) + if f.seekLeft < 100 { + f.seekLeft = 100 + } + + return f +} + +// tFiles hold multiple tFile. +type tFiles []*tFile + +func (tf tFiles) Len() int { return len(tf) } +func (tf tFiles) Swap(i, j int) { tf[i], tf[j] = tf[j], tf[i] } + +func (tf tFiles) nums() string { + x := "[ " + for i, f := range tf { + if i != 0 { + x += ", " + } + x += fmt.Sprint(f.file.Num()) + } + x += " ]" + return x +} + +// Returns true if i smallest key is less than j. +// This used for sort by key in ascending order. +func (tf tFiles) lessByKey(icmp *iComparer, i, j int) bool { + a, b := tf[i], tf[j] + n := icmp.Compare(a.imin, b.imin) + if n == 0 { + return a.file.Num() < b.file.Num() + } + return n < 0 +} + +// Returns true if i file number is greater than j. +// This used for sort by file number in descending order. +func (tf tFiles) lessByNum(i, j int) bool { + return tf[i].file.Num() > tf[j].file.Num() +} + +// Sorts tables by key in ascending order. +func (tf tFiles) sortByKey(icmp *iComparer) { + sort.Sort(&tFilesSortByKey{tFiles: tf, icmp: icmp}) +} + +// Sorts tables by file number in descending order. +func (tf tFiles) sortByNum() { + sort.Sort(&tFilesSortByNum{tFiles: tf}) +} + +// Returns sum of all tables size. +func (tf tFiles) size() (sum uint64) { + for _, t := range tf { + sum += t.size + } + return sum +} + +// Searches smallest index of tables whose its smallest +// key is after or equal with given key. +func (tf tFiles) searchMin(icmp *iComparer, ikey iKey) int { + return sort.Search(len(tf), func(i int) bool { + return icmp.Compare(tf[i].imin, ikey) >= 0 + }) +} + +// Searches smallest index of tables whose its largest +// key is after or equal with given key. +func (tf tFiles) searchMax(icmp *iComparer, ikey iKey) int { + return sort.Search(len(tf), func(i int) bool { + return icmp.Compare(tf[i].imax, ikey) >= 0 + }) +} + +// Returns true if given key range overlaps with one or more +// tables key range. If unsorted is true then binary search will not be used. +func (tf tFiles) overlaps(icmp *iComparer, umin, umax []byte, unsorted bool) bool { + if unsorted { + // Check against all files. + for _, t := range tf { + if t.overlaps(icmp, umin, umax) { + return true + } + } + return false + } + + i := 0 + if len(umin) > 0 { + // Find the earliest possible internal key for min. + i = tf.searchMax(icmp, newIkey(umin, kMaxSeq, ktSeek)) + } + if i >= len(tf) { + // Beginning of range is after all files, so no overlap. + return false + } + return !tf[i].before(icmp, umax) +} + +// Returns tables whose its key range overlaps with given key range. +// Range will be expanded if ukey found hop across tables. +// If overlapped is true then the search will be restarted if umax +// expanded. +// The dst content will be overwritten. +func (tf tFiles) getOverlaps(dst tFiles, icmp *iComparer, umin, umax []byte, overlapped bool) tFiles { + dst = dst[:0] + for i := 0; i < len(tf); { + t := tf[i] + if t.overlaps(icmp, umin, umax) { + if umin != nil && icmp.uCompare(t.imin.ukey(), umin) < 0 { + umin = t.imin.ukey() + dst = dst[:0] + i = 0 + continue + } else if umax != nil && icmp.uCompare(t.imax.ukey(), umax) > 0 { + umax = t.imax.ukey() + // Restart search if it is overlapped. + if overlapped { + dst = dst[:0] + i = 0 + continue + } + } + + dst = append(dst, t) + } + i++ + } + + return dst +} + +// Returns tables key range. 
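+// For example, for tables covering [a..c], [b..f] and [x..z], the
+// result is the range [a..z].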
+func (tf tFiles) getRange(icmp *iComparer) (imin, imax iKey) { + for i, t := range tf { + if i == 0 { + imin, imax = t.imin, t.imax + continue + } + if icmp.Compare(t.imin, imin) < 0 { + imin = t.imin + } + if icmp.Compare(t.imax, imax) > 0 { + imax = t.imax + } + } + + return +} + +// Creates iterator index from tables. +func (tf tFiles) newIndexIterator(tops *tOps, icmp *iComparer, slice *util.Range, ro *opt.ReadOptions) iterator.IteratorIndexer { + if slice != nil { + var start, limit int + if slice.Start != nil { + start = tf.searchMax(icmp, iKey(slice.Start)) + } + if slice.Limit != nil { + limit = tf.searchMin(icmp, iKey(slice.Limit)) + } else { + limit = tf.Len() + } + tf = tf[start:limit] + } + return iterator.NewArrayIndexer(&tFilesArrayIndexer{ + tFiles: tf, + tops: tops, + icmp: icmp, + slice: slice, + ro: ro, + }) +} + +// Tables iterator index. +type tFilesArrayIndexer struct { + tFiles + tops *tOps + icmp *iComparer + slice *util.Range + ro *opt.ReadOptions +} + +func (a *tFilesArrayIndexer) Search(key []byte) int { + return a.searchMax(a.icmp, iKey(key)) +} + +func (a *tFilesArrayIndexer) Get(i int) iterator.Iterator { + if i == 0 || i == a.Len()-1 { + return a.tops.newIterator(a.tFiles[i], a.slice, a.ro) + } + return a.tops.newIterator(a.tFiles[i], nil, a.ro) +} + +// Helper type for sortByKey. +type tFilesSortByKey struct { + tFiles + icmp *iComparer +} + +func (x *tFilesSortByKey) Less(i, j int) bool { + return x.lessByKey(x.icmp, i, j) +} + +// Helper type for sortByNum. +type tFilesSortByNum struct { + tFiles +} + +func (x *tFilesSortByNum) Less(i, j int) bool { + return x.lessByNum(i, j) +} + +// Table operations. +type tOps struct { + s *session + cache *cache.Cache + bcache *cache.Cache + bpool *util.BufferPool +} + +// Creates an empty table and returns table writer. +func (t *tOps) create() (*tWriter, error) { + file := t.s.getTableFile(t.s.allocFileNum()) + fw, err := file.Create() + if err != nil { + return nil, err + } + return &tWriter{ + t: t, + file: file, + w: fw, + tw: table.NewWriter(fw, t.s.o.Options), + }, nil +} + +// Builds table from src iterator. +func (t *tOps) createFrom(src iterator.Iterator) (f *tFile, n int, err error) { + w, err := t.create() + if err != nil { + return + } + + defer func() { + if err != nil { + w.drop() + } + }() + + for src.Next() { + err = w.append(src.Key(), src.Value()) + if err != nil { + return + } + } + err = src.Error() + if err != nil { + return + } + + n = w.tw.EntriesLen() + f, err = w.finish() + return +} + +// Opens table. It returns a cache handle, which should +// be released after use. +func (t *tOps) open(f *tFile) (ch *cache.Handle, err error) { + num := f.file.Num() + ch = t.cache.Get(0, num, func() (size int, value cache.Value) { + var r storage.Reader + r, err = f.file.Open() + if err != nil { + return 0, nil + } + + var bcache *cache.CacheGetter + if t.bcache != nil { + bcache = &cache.CacheGetter{Cache: t.bcache, NS: num} + } + + var tr *table.Reader + tr, err = table.NewReader(r, int64(f.size), storage.NewFileInfo(f.file), bcache, t.bpool, t.s.o.Options) + if err != nil { + r.Close() + return 0, nil + } + return 1, tr + + }) + if ch == nil && err == nil { + err = ErrClosed + } + return +} + +// Finds key/value pair whose key is greater than or equal to the +// given key. 
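+// It returns ErrNotFound if the table does not contain such a pair.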
+func (t *tOps) find(f *tFile, key []byte, ro *opt.ReadOptions) (rkey, rvalue []byte, err error) { + ch, err := t.open(f) + if err != nil { + return nil, nil, err + } + defer ch.Release() + return ch.Value().(*table.Reader).Find(key, true, ro) +} + +// Finds key that is greater than or equal to the given key. +func (t *tOps) findKey(f *tFile, key []byte, ro *opt.ReadOptions) (rkey []byte, err error) { + ch, err := t.open(f) + if err != nil { + return nil, err + } + defer ch.Release() + return ch.Value().(*table.Reader).FindKey(key, true, ro) +} + +// Returns approximate offset of the given key. +func (t *tOps) offsetOf(f *tFile, key []byte) (offset uint64, err error) { + ch, err := t.open(f) + if err != nil { + return + } + defer ch.Release() + offset_, err := ch.Value().(*table.Reader).OffsetOf(key) + return uint64(offset_), err +} + +// Creates an iterator from the given table. +func (t *tOps) newIterator(f *tFile, slice *util.Range, ro *opt.ReadOptions) iterator.Iterator { + ch, err := t.open(f) + if err != nil { + return iterator.NewEmptyIterator(err) + } + iter := ch.Value().(*table.Reader).NewIterator(slice, ro) + iter.SetReleaser(ch) + return iter +} + +// Removes table from persistent storage. It waits until +// no one use the the table. +func (t *tOps) remove(f *tFile) { + num := f.file.Num() + t.cache.Delete(0, num, func() { + if err := f.file.Remove(); err != nil { + t.s.logf("table@remove removing @%d %q", num, err) + } else { + t.s.logf("table@remove removed @%d", num) + } + if t.bcache != nil { + t.bcache.EvictNS(num) + } + }) +} + +// Closes the table ops instance. It will close all tables, +// regadless still used or not. +func (t *tOps) close() { + t.bpool.Close() + t.cache.Close() + if t.bcache != nil { + t.bcache.Close() + } +} + +// Creates new initialized table ops instance. +func newTableOps(s *session) *tOps { + var ( + cacher cache.Cacher + bcache *cache.Cache + ) + if s.o.GetOpenFilesCacheCapacity() > 0 { + cacher = cache.NewLRU(s.o.GetOpenFilesCacheCapacity()) + } + if !s.o.DisableBlockCache { + var bcacher cache.Cacher + if s.o.GetBlockCacheCapacity() > 0 { + bcacher = cache.NewLRU(s.o.GetBlockCacheCapacity()) + } + bcache = cache.NewCache(bcacher) + } + return &tOps{ + s: s, + cache: cache.NewCache(cacher), + bcache: bcache, + bpool: util.NewBufferPool(s.o.GetBlockSize() + 5), + } +} + +// tWriter wraps the table writer. It keep track of file descriptor +// and added key range. +type tWriter struct { + t *tOps + + file storage.File + w storage.Writer + tw *table.Writer + + first, last []byte +} + +// Append key/value pair to the table. +func (w *tWriter) append(key, value []byte) error { + if w.first == nil { + w.first = append([]byte{}, key...) + } + w.last = append(w.last[:0], key...) + return w.tw.Append(key, value) +} + +// Returns true if the table is empty. +func (w *tWriter) empty() bool { + return w.first == nil +} + +// Closes the storage.Writer. +func (w *tWriter) close() { + if w.w != nil { + w.w.Close() + w.w = nil + } +} + +// Finalizes the table and returns table file. +func (w *tWriter) finish() (f *tFile, err error) { + defer w.close() + err = w.tw.Close() + if err != nil { + return + } + err = w.w.Sync() + if err != nil { + return + } + f = newTableFile(w.file, uint64(w.tw.BytesLen()), iKey(w.first), iKey(w.last)) + return +} + +// Drops the table. 
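+// The table file is removed and its file number is given back to the
+// session for reuse.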
+func (w *tWriter) drop() { + w.close() + w.file.Remove() + w.t.s.reuseFileNum(w.file.Num()) + w.file = nil + w.tw = nil + w.first = nil + w.last = nil +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/table/block_test.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/table/block_test.go new file mode 100644 index 0000000000..00e6f9eea0 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/table/block_test.go @@ -0,0 +1,139 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package table + +import ( + "encoding/binary" + "fmt" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "github.com/syndtr/goleveldb/leveldb/comparer" + "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/testutil" + "github.com/syndtr/goleveldb/leveldb/util" +) + +type blockTesting struct { + tr *Reader + b *block +} + +func (t *blockTesting) TestNewIterator(slice *util.Range) iterator.Iterator { + return t.tr.newBlockIter(t.b, nil, slice, false) +} + +var _ = testutil.Defer(func() { + Describe("Block", func() { + Build := func(kv *testutil.KeyValue, restartInterval int) *blockTesting { + // Building the block. + bw := &blockWriter{ + restartInterval: restartInterval, + scratch: make([]byte, 30), + } + kv.Iterate(func(i int, key, value []byte) { + bw.append(key, value) + }) + bw.finish() + + // Opening the block. + data := bw.buf.Bytes() + restartsLen := int(binary.LittleEndian.Uint32(data[len(data)-4:])) + return &blockTesting{ + tr: &Reader{cmp: comparer.DefaultComparer}, + b: &block{ + data: data, + restartsLen: restartsLen, + restartsOffset: len(data) - (restartsLen+1)*4, + }, + } + } + + Describe("read test", func() { + for restartInterval := 1; restartInterval <= 5; restartInterval++ { + Describe(fmt.Sprintf("with restart interval of %d", restartInterval), func() { + kv := &testutil.KeyValue{} + Text := func() string { + return fmt.Sprintf("and %d keys", kv.Len()) + } + + Test := func() { + // Make block. + br := Build(kv, restartInterval) + // Do testing. + testutil.KeyValueTesting(nil, kv.Clone(), br, nil, nil) + } + + Describe(Text(), Test) + + kv.PutString("", "empty") + Describe(Text(), Test) + + kv.PutString("a1", "foo") + Describe(Text(), Test) + + kv.PutString("a2", "v") + Describe(Text(), Test) + + kv.PutString("a3qqwrkks", "hello") + Describe(Text(), Test) + + kv.PutString("a4", "bar") + Describe(Text(), Test) + + kv.PutString("a5111111", "v5") + kv.PutString("a6", "") + kv.PutString("a7", "v7") + kv.PutString("a8", "vvvvvvvvvvvvvvvvvvvvvv8") + kv.PutString("b", "v9") + kv.PutString("c9", "v9") + kv.PutString("c91", "v9") + kv.PutString("d0", "v9") + Describe(Text(), Test) + }) + } + }) + + Describe("out-of-bound slice test", func() { + kv := &testutil.KeyValue{} + kv.PutString("k1", "v1") + kv.PutString("k2", "v2") + kv.PutString("k3abcdefgg", "v3") + kv.PutString("k4", "v4") + kv.PutString("k5", "v5") + for restartInterval := 1; restartInterval <= 5; restartInterval++ { + Describe(fmt.Sprintf("with restart interval of %d", restartInterval), func() { + // Make block. 
+ bt := Build(kv, restartInterval) + + Test := func(r *util.Range) func(done Done) { + return func(done Done) { + iter := bt.TestNewIterator(r) + Expect(iter.Error()).ShouldNot(HaveOccurred()) + + t := testutil.IteratorTesting{ + KeyValue: kv.Clone(), + Iter: iter, + } + + testutil.DoIteratorTesting(&t) + iter.Release() + done <- true + } + } + + It("Should do iterations and seeks correctly #0", + Test(&util.Range{Start: []byte("k0"), Limit: []byte("k6")}), 2.0) + + It("Should do iterations and seeks correctly #1", + Test(&util.Range{Start: []byte(""), Limit: []byte("zzzzzzz")}), 2.0) + }) + } + }) + }) +}) diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/table/reader.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/table/reader.go new file mode 100644 index 0000000000..6f38e84b39 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/table/reader.go @@ -0,0 +1,1107 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package table + +import ( + "encoding/binary" + "fmt" + "io" + "sort" + "strings" + "sync" + + "github.com/syndtr/gosnappy/snappy" + + "github.com/syndtr/goleveldb/leveldb/cache" + "github.com/syndtr/goleveldb/leveldb/comparer" + "github.com/syndtr/goleveldb/leveldb/errors" + "github.com/syndtr/goleveldb/leveldb/filter" + "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/opt" + "github.com/syndtr/goleveldb/leveldb/storage" + "github.com/syndtr/goleveldb/leveldb/util" +) + +var ( + ErrNotFound = errors.ErrNotFound + ErrReaderReleased = errors.New("leveldb/table: reader released") + ErrIterReleased = errors.New("leveldb/table: iterator released") +) + +type ErrCorrupted struct { + Pos int64 + Size int64 + Kind string + Reason string +} + +func (e *ErrCorrupted) Error() string { + return fmt.Sprintf("leveldb/table: corruption on %s (pos=%d): %s", e.Kind, e.Pos, e.Reason) +} + +func max(x, y int) int { + if x > y { + return x + } + return y +} + +type block struct { + bpool *util.BufferPool + bh blockHandle + data []byte + restartsLen int + restartsOffset int +} + +func (b *block) seek(cmp comparer.Comparer, rstart, rlimit int, key []byte) (index, offset int, err error) { + index = sort.Search(b.restartsLen-rstart-(b.restartsLen-rlimit), func(i int) bool { + offset := int(binary.LittleEndian.Uint32(b.data[b.restartsOffset+4*(rstart+i):])) + offset += 1 // shared always zero, since this is a restart point + v1, n1 := binary.Uvarint(b.data[offset:]) // key length + _, n2 := binary.Uvarint(b.data[offset+n1:]) // value length + m := offset + n1 + n2 + return cmp.Compare(b.data[m:m+int(v1)], key) > 0 + }) + rstart - 1 + if index < rstart { + // The smallest key is greater-than key sought. 
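+		// (sort.Search returned an index below rstart, meaning even the
+		// first restart point holds a key greater than the key sought, so
+		// the search is clamped to the first restart range.)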
+ index = rstart + } + offset = int(binary.LittleEndian.Uint32(b.data[b.restartsOffset+4*index:])) + return +} + +func (b *block) restartIndex(rstart, rlimit, offset int) int { + return sort.Search(b.restartsLen-rstart-(b.restartsLen-rlimit), func(i int) bool { + return int(binary.LittleEndian.Uint32(b.data[b.restartsOffset+4*(rstart+i):])) > offset + }) + rstart - 1 +} + +func (b *block) restartOffset(index int) int { + return int(binary.LittleEndian.Uint32(b.data[b.restartsOffset+4*index:])) +} + +func (b *block) entry(offset int) (key, value []byte, nShared, n int, err error) { + if offset >= b.restartsOffset { + if offset != b.restartsOffset { + err = &ErrCorrupted{Reason: "entries offset not aligned"} + } + return + } + v0, n0 := binary.Uvarint(b.data[offset:]) // Shared prefix length + v1, n1 := binary.Uvarint(b.data[offset+n0:]) // Key length + v2, n2 := binary.Uvarint(b.data[offset+n0+n1:]) // Value length + m := n0 + n1 + n2 + n = m + int(v1) + int(v2) + if n0 <= 0 || n1 <= 0 || n2 <= 0 || offset+n > b.restartsOffset { + err = &ErrCorrupted{Reason: "entries corrupted"} + return + } + key = b.data[offset+m : offset+m+int(v1)] + value = b.data[offset+m+int(v1) : offset+n] + nShared = int(v0) + return +} + +func (b *block) Release() { + b.bpool.Put(b.data) + b.bpool = nil + b.data = nil +} + +type dir int + +const ( + dirReleased dir = iota - 1 + dirSOI + dirEOI + dirBackward + dirForward +) + +type blockIter struct { + tr *Reader + block *block + blockReleaser util.Releaser + releaser util.Releaser + key, value []byte + offset int + // Previous offset, only filled by Next. + prevOffset int + prevNode []int + prevKeys []byte + restartIndex int + // Iterator direction. + dir dir + // Restart index slice range. + riStart int + riLimit int + // Offset slice range. + offsetStart int + offsetRealStart int + offsetLimit int + // Error. 
+ err error +} + +func (i *blockIter) sErr(err error) { + i.err = err + i.key = nil + i.value = nil + i.prevNode = nil + i.prevKeys = nil +} + +func (i *blockIter) reset() { + if i.dir == dirBackward { + i.prevNode = i.prevNode[:0] + i.prevKeys = i.prevKeys[:0] + } + i.restartIndex = i.riStart + i.offset = i.offsetStart + i.dir = dirSOI + i.key = i.key[:0] + i.value = nil +} + +func (i *blockIter) isFirst() bool { + switch i.dir { + case dirForward: + return i.prevOffset == i.offsetRealStart + case dirBackward: + return len(i.prevNode) == 1 && i.restartIndex == i.riStart + } + return false +} + +func (i *blockIter) isLast() bool { + switch i.dir { + case dirForward, dirBackward: + return i.offset == i.offsetLimit + } + return false +} + +func (i *blockIter) First() bool { + if i.err != nil { + return false + } else if i.dir == dirReleased { + i.err = ErrIterReleased + return false + } + + if i.dir == dirBackward { + i.prevNode = i.prevNode[:0] + i.prevKeys = i.prevKeys[:0] + } + i.dir = dirSOI + return i.Next() +} + +func (i *blockIter) Last() bool { + if i.err != nil { + return false + } else if i.dir == dirReleased { + i.err = ErrIterReleased + return false + } + + if i.dir == dirBackward { + i.prevNode = i.prevNode[:0] + i.prevKeys = i.prevKeys[:0] + } + i.dir = dirEOI + return i.Prev() +} + +func (i *blockIter) Seek(key []byte) bool { + if i.err != nil { + return false + } else if i.dir == dirReleased { + i.err = ErrIterReleased + return false + } + + ri, offset, err := i.block.seek(i.tr.cmp, i.riStart, i.riLimit, key) + if err != nil { + i.sErr(err) + return false + } + i.restartIndex = ri + i.offset = max(i.offsetStart, offset) + if i.dir == dirSOI || i.dir == dirEOI { + i.dir = dirForward + } + for i.Next() { + if i.tr.cmp.Compare(i.key, key) >= 0 { + return true + } + } + return false +} + +func (i *blockIter) Next() bool { + if i.dir == dirEOI || i.err != nil { + return false + } else if i.dir == dirReleased { + i.err = ErrIterReleased + return false + } + + if i.dir == dirSOI { + i.restartIndex = i.riStart + i.offset = i.offsetStart + } else if i.dir == dirBackward { + i.prevNode = i.prevNode[:0] + i.prevKeys = i.prevKeys[:0] + } + for i.offset < i.offsetRealStart { + key, value, nShared, n, err := i.block.entry(i.offset) + if err != nil { + i.sErr(i.tr.fixErrCorruptedBH(i.block.bh, err)) + return false + } + if n == 0 { + i.dir = dirEOI + return false + } + i.key = append(i.key[:nShared], key...) + i.value = value + i.offset += n + } + if i.offset >= i.offsetLimit { + i.dir = dirEOI + if i.offset != i.offsetLimit { + i.sErr(i.tr.newErrCorruptedBH(i.block.bh, "entries offset not aligned")) + } + return false + } + key, value, nShared, n, err := i.block.entry(i.offset) + if err != nil { + i.sErr(i.tr.fixErrCorruptedBH(i.block.bh, err)) + return false + } + if n == 0 { + i.dir = dirEOI + return false + } + i.key = append(i.key[:nShared], key...) + i.value = value + i.prevOffset = i.offset + i.offset += n + i.dir = dirForward + return true +} + +func (i *blockIter) Prev() bool { + if i.dir == dirSOI || i.err != nil { + return false + } else if i.dir == dirReleased { + i.err = ErrIterReleased + return false + } + + var ri int + if i.dir == dirForward { + // Change direction. + i.offset = i.prevOffset + if i.offset == i.offsetRealStart { + i.dir = dirSOI + return false + } + ri = i.block.restartIndex(i.restartIndex, i.riLimit, i.offset) + i.dir = dirBackward + } else if i.dir == dirEOI { + // At the end of iterator. 
+ i.restartIndex = i.riLimit + i.offset = i.offsetLimit + if i.offset == i.offsetRealStart { + i.dir = dirSOI + return false + } + ri = i.riLimit - 1 + i.dir = dirBackward + } else if len(i.prevNode) == 1 { + // This is the end of a restart range. + i.offset = i.prevNode[0] + i.prevNode = i.prevNode[:0] + if i.restartIndex == i.riStart { + i.dir = dirSOI + return false + } + i.restartIndex-- + ri = i.restartIndex + } else { + // In the middle of restart range, get from cache. + n := len(i.prevNode) - 3 + node := i.prevNode[n:] + i.prevNode = i.prevNode[:n] + // Get the key. + ko := node[0] + i.key = append(i.key[:0], i.prevKeys[ko:]...) + i.prevKeys = i.prevKeys[:ko] + // Get the value. + vo := node[1] + vl := vo + node[2] + i.value = i.block.data[vo:vl] + i.offset = vl + return true + } + // Build entries cache. + i.key = i.key[:0] + i.value = nil + offset := i.block.restartOffset(ri) + if offset == i.offset { + ri -= 1 + if ri < 0 { + i.dir = dirSOI + return false + } + offset = i.block.restartOffset(ri) + } + i.prevNode = append(i.prevNode, offset) + for { + key, value, nShared, n, err := i.block.entry(offset) + if err != nil { + i.sErr(i.tr.fixErrCorruptedBH(i.block.bh, err)) + return false + } + if offset >= i.offsetRealStart { + if i.value != nil { + // Appends 3 variables: + // 1. Previous keys offset + // 2. Value offset in the data block + // 3. Value length + i.prevNode = append(i.prevNode, len(i.prevKeys), offset-len(i.value), len(i.value)) + i.prevKeys = append(i.prevKeys, i.key...) + } + i.value = value + } + i.key = append(i.key[:nShared], key...) + offset += n + // Stop if target offset reached. + if offset >= i.offset { + if offset != i.offset { + i.sErr(i.tr.newErrCorruptedBH(i.block.bh, "entries offset not aligned")) + return false + } + + break + } + } + i.restartIndex = ri + i.offset = offset + return true +} + +func (i *blockIter) Key() []byte { + if i.err != nil || i.dir <= dirEOI { + return nil + } + return i.key +} + +func (i *blockIter) Value() []byte { + if i.err != nil || i.dir <= dirEOI { + return nil + } + return i.value +} + +func (i *blockIter) Release() { + if i.dir != dirReleased { + i.tr = nil + i.block = nil + i.prevNode = nil + i.prevKeys = nil + i.key = nil + i.value = nil + i.dir = dirReleased + if i.blockReleaser != nil { + i.blockReleaser.Release() + i.blockReleaser = nil + } + if i.releaser != nil { + i.releaser.Release() + i.releaser = nil + } + } +} + +func (i *blockIter) SetReleaser(releaser util.Releaser) { + if i.dir == dirReleased { + panic(util.ErrReleased) + } + if i.releaser != nil && releaser != nil { + panic(util.ErrHasReleaser) + } + i.releaser = releaser +} + +func (i *blockIter) Valid() bool { + return i.err == nil && (i.dir == dirBackward || i.dir == dirForward) +} + +func (i *blockIter) Error() error { + return i.err +} + +type filterBlock struct { + bpool *util.BufferPool + data []byte + oOffset int + baseLg uint + filtersNum int +} + +func (b *filterBlock) contains(filter filter.Filter, offset uint64, key []byte) bool { + i := int(offset >> b.baseLg) + if i < b.filtersNum { + o := b.data[b.oOffset+i*4:] + n := int(binary.LittleEndian.Uint32(o)) + m := int(binary.LittleEndian.Uint32(o[4:])) + if n < m && m <= b.oOffset { + return filter.Contains(b.data[n:m], key) + } else if n == m { + return false + } + } + return true +} + +func (b *filterBlock) Release() { + b.bpool.Put(b.data) + b.bpool = nil + b.data = nil +} + +type indexIter struct { + *blockIter + tr *Reader + slice *util.Range + // Options + fillCache bool +} + +func (i 
*indexIter) Get() iterator.Iterator { + value := i.Value() + if value == nil { + return nil + } + dataBH, n := decodeBlockHandle(value) + if n == 0 { + return iterator.NewEmptyIterator(i.tr.newErrCorruptedBH(i.tr.indexBH, "bad data block handle")) + } + + var slice *util.Range + if i.slice != nil && (i.blockIter.isFirst() || i.blockIter.isLast()) { + slice = i.slice + } + return i.tr.getDataIterErr(dataBH, slice, i.tr.verifyChecksum, i.fillCache) +} + +// Reader is a table reader. +type Reader struct { + mu sync.RWMutex + fi *storage.FileInfo + reader io.ReaderAt + cache *cache.CacheGetter + err error + bpool *util.BufferPool + // Options + o *opt.Options + cmp comparer.Comparer + filter filter.Filter + verifyChecksum bool + + dataEnd int64 + metaBH, indexBH, filterBH blockHandle + indexBlock *block + filterBlock *filterBlock +} + +func (r *Reader) blockKind(bh blockHandle) string { + switch bh.offset { + case r.metaBH.offset: + return "meta-block" + case r.indexBH.offset: + return "index-block" + case r.filterBH.offset: + if r.filterBH.length > 0 { + return "filter-block" + } + } + return "data-block" +} + +func (r *Reader) newErrCorrupted(pos, size int64, kind, reason string) error { + return &errors.ErrCorrupted{File: r.fi, Err: &ErrCorrupted{Pos: pos, Size: size, Kind: kind, Reason: reason}} +} + +func (r *Reader) newErrCorruptedBH(bh blockHandle, reason string) error { + return r.newErrCorrupted(int64(bh.offset), int64(bh.length), r.blockKind(bh), reason) +} + +func (r *Reader) fixErrCorruptedBH(bh blockHandle, err error) error { + if cerr, ok := err.(*ErrCorrupted); ok { + cerr.Pos = int64(bh.offset) + cerr.Size = int64(bh.length) + cerr.Kind = r.blockKind(bh) + return &errors.ErrCorrupted{File: r.fi, Err: cerr} + } + return err +} + +func (r *Reader) readRawBlock(bh blockHandle, verifyChecksum bool) ([]byte, error) { + data := r.bpool.Get(int(bh.length + blockTrailerLen)) + if _, err := r.reader.ReadAt(data, int64(bh.offset)); err != nil && err != io.EOF { + return nil, err + } + + if verifyChecksum { + n := bh.length + 1 + checksum0 := binary.LittleEndian.Uint32(data[n:]) + checksum1 := util.NewCRC(data[:n]).Value() + if checksum0 != checksum1 { + r.bpool.Put(data) + return nil, r.newErrCorruptedBH(bh, fmt.Sprintf("checksum mismatch, want=%#x got=%#x", checksum0, checksum1)) + } + } + + switch data[bh.length] { + case blockTypeNoCompression: + data = data[:bh.length] + case blockTypeSnappyCompression: + decLen, err := snappy.DecodedLen(data[:bh.length]) + if err != nil { + return nil, r.newErrCorruptedBH(bh, err.Error()) + } + decData := r.bpool.Get(decLen) + decData, err = snappy.Decode(decData, data[:bh.length]) + r.bpool.Put(data) + if err != nil { + r.bpool.Put(decData) + return nil, r.newErrCorruptedBH(bh, err.Error()) + } + data = decData + default: + r.bpool.Put(data) + return nil, r.newErrCorruptedBH(bh, fmt.Sprintf("unknown compression type %#x", data[bh.length])) + } + return data, nil +} + +func (r *Reader) readBlock(bh blockHandle, verifyChecksum bool) (*block, error) { + data, err := r.readRawBlock(bh, verifyChecksum) + if err != nil { + return nil, err + } + restartsLen := int(binary.LittleEndian.Uint32(data[len(data)-4:])) + b := &block{ + bpool: r.bpool, + bh: bh, + data: data, + restartsLen: restartsLen, + restartsOffset: len(data) - (restartsLen+1)*4, + } + return b, nil +} + +func (r *Reader) readBlockCached(bh blockHandle, verifyChecksum, fillCache bool) (*block, util.Releaser, error) { + if r.cache != nil { + var ( + err error + ch *cache.Handle + ) + if 
fillCache { + ch = r.cache.Get(bh.offset, func() (size int, value cache.Value) { + var b *block + b, err = r.readBlock(bh, verifyChecksum) + if err != nil { + return 0, nil + } + return cap(b.data), b + }) + } else { + ch = r.cache.Get(bh.offset, nil) + } + if ch != nil { + b, ok := ch.Value().(*block) + if !ok { + ch.Release() + return nil, nil, errors.New("leveldb/table: inconsistent block type") + } + return b, ch, err + } else if err != nil { + return nil, nil, err + } + } + + b, err := r.readBlock(bh, verifyChecksum) + return b, b, err +} + +func (r *Reader) readFilterBlock(bh blockHandle) (*filterBlock, error) { + data, err := r.readRawBlock(bh, true) + if err != nil { + return nil, err + } + n := len(data) + if n < 5 { + return nil, r.newErrCorruptedBH(bh, "too short") + } + m := n - 5 + oOffset := int(binary.LittleEndian.Uint32(data[m:])) + if oOffset > m { + return nil, r.newErrCorruptedBH(bh, "invalid data-offsets offset") + } + b := &filterBlock{ + bpool: r.bpool, + data: data, + oOffset: oOffset, + baseLg: uint(data[n-1]), + filtersNum: (m - oOffset) / 4, + } + return b, nil +} + +func (r *Reader) readFilterBlockCached(bh blockHandle, fillCache bool) (*filterBlock, util.Releaser, error) { + if r.cache != nil { + var ( + err error + ch *cache.Handle + ) + if fillCache { + ch = r.cache.Get(bh.offset, func() (size int, value cache.Value) { + var b *filterBlock + b, err = r.readFilterBlock(bh) + if err != nil { + return 0, nil + } + return cap(b.data), b + }) + } else { + ch = r.cache.Get(bh.offset, nil) + } + if ch != nil { + b, ok := ch.Value().(*filterBlock) + if !ok { + ch.Release() + return nil, nil, errors.New("leveldb/table: inconsistent block type") + } + return b, ch, err + } else if err != nil { + return nil, nil, err + } + } + + b, err := r.readFilterBlock(bh) + return b, b, err +} + +func (r *Reader) getIndexBlock(fillCache bool) (b *block, rel util.Releaser, err error) { + if r.indexBlock == nil { + return r.readBlockCached(r.indexBH, true, fillCache) + } + return r.indexBlock, util.NoopReleaser{}, nil +} + +func (r *Reader) getFilterBlock(fillCache bool) (*filterBlock, util.Releaser, error) { + if r.filterBlock == nil { + return r.readFilterBlockCached(r.filterBH, fillCache) + } + return r.filterBlock, util.NoopReleaser{}, nil +} + +func (r *Reader) newBlockIter(b *block, bReleaser util.Releaser, slice *util.Range, inclLimit bool) *blockIter { + bi := &blockIter{ + tr: r, + block: b, + blockReleaser: bReleaser, + // Valid key should never be nil. 
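+		// An empty, non-nil slice is used for that reason.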
+ key: make([]byte, 0), + dir: dirSOI, + riStart: 0, + riLimit: b.restartsLen, + offsetStart: 0, + offsetRealStart: 0, + offsetLimit: b.restartsOffset, + } + if slice != nil { + if slice.Start != nil { + if bi.Seek(slice.Start) { + bi.riStart = b.restartIndex(bi.restartIndex, b.restartsLen, bi.prevOffset) + bi.offsetStart = b.restartOffset(bi.riStart) + bi.offsetRealStart = bi.prevOffset + } else { + bi.riStart = b.restartsLen + bi.offsetStart = b.restartsOffset + bi.offsetRealStart = b.restartsOffset + } + } + if slice.Limit != nil { + if bi.Seek(slice.Limit) && (!inclLimit || bi.Next()) { + bi.offsetLimit = bi.prevOffset + bi.riLimit = bi.restartIndex + 1 + } + } + bi.reset() + if bi.offsetStart > bi.offsetLimit { + bi.sErr(errors.New("leveldb/table: invalid slice range")) + } + } + return bi +} + +func (r *Reader) getDataIter(dataBH blockHandle, slice *util.Range, verifyChecksum, fillCache bool) iterator.Iterator { + b, rel, err := r.readBlockCached(dataBH, verifyChecksum, fillCache) + if err != nil { + return iterator.NewEmptyIterator(err) + } + return r.newBlockIter(b, rel, slice, false) +} + +func (r *Reader) getDataIterErr(dataBH blockHandle, slice *util.Range, verifyChecksum, fillCache bool) iterator.Iterator { + r.mu.RLock() + defer r.mu.RUnlock() + + if r.err != nil { + return iterator.NewEmptyIterator(r.err) + } + + return r.getDataIter(dataBH, slice, verifyChecksum, fillCache) +} + +// NewIterator creates an iterator from the table. +// +// Slice allows slicing the iterator to only contains keys in the given +// range. A nil Range.Start is treated as a key before all keys in the +// table. And a nil Range.Limit is treated as a key after all keys in +// the table. +// +// The returned iterator is not goroutine-safe and should be released +// when not used. +// +// Also read Iterator documentation of the leveldb/iterator package. 
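+//
+// A minimal usage sketch (assuming tr is a *Reader obtained from
+// NewReader; error handling elided):
+//
+//	iter := tr.NewIterator(nil, nil)
+//	for iter.Next() {
+//		fmt.Printf("%s: %s\n", iter.Key(), iter.Value())
+//	}
+//	iter.Release()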
+func (r *Reader) NewIterator(slice *util.Range, ro *opt.ReadOptions) iterator.Iterator { + r.mu.RLock() + defer r.mu.RUnlock() + + if r.err != nil { + return iterator.NewEmptyIterator(r.err) + } + + fillCache := !ro.GetDontFillCache() + indexBlock, rel, err := r.getIndexBlock(fillCache) + if err != nil { + return iterator.NewEmptyIterator(err) + } + index := &indexIter{ + blockIter: r.newBlockIter(indexBlock, rel, slice, true), + tr: r, + slice: slice, + fillCache: !ro.GetDontFillCache(), + } + return iterator.NewIndexedIterator(index, opt.GetStrict(r.o, ro, opt.StrictReader)) +} + +func (r *Reader) find(key []byte, filtered bool, ro *opt.ReadOptions, noValue bool) (rkey, value []byte, err error) { + r.mu.RLock() + defer r.mu.RUnlock() + + if r.err != nil { + err = r.err + return + } + + indexBlock, rel, err := r.getIndexBlock(true) + if err != nil { + return + } + defer rel.Release() + + index := r.newBlockIter(indexBlock, nil, nil, true) + defer index.Release() + if !index.Seek(key) { + err = index.Error() + if err == nil { + err = ErrNotFound + } + return + } + dataBH, n := decodeBlockHandle(index.Value()) + if n == 0 { + r.err = r.newErrCorruptedBH(r.indexBH, "bad data block handle") + return + } + if filtered && r.filter != nil { + filterBlock, frel, ferr := r.getFilterBlock(true) + if ferr == nil { + if !filterBlock.contains(r.filter, dataBH.offset, key) { + frel.Release() + return nil, nil, ErrNotFound + } + frel.Release() + } else if !errors.IsCorrupted(ferr) { + err = ferr + return + } + } + data := r.getDataIter(dataBH, nil, r.verifyChecksum, !ro.GetDontFillCache()) + defer data.Release() + if !data.Seek(key) { + err = data.Error() + if err == nil { + err = ErrNotFound + } + return + } + // Don't use block buffer, no need to copy the buffer. + rkey = data.Key() + if !noValue { + if r.bpool == nil { + value = data.Value() + } else { + // Use block buffer, and since the buffer will be recycled, the buffer + // need to be copied. + value = append([]byte{}, data.Value()...) + } + } + return +} + +// Find finds key/value pair whose key is greater than or equal to the +// given key. It returns ErrNotFound if the table doesn't contain +// such pair. +// If filtered is true then the nearest 'block' will be checked against +// 'filter data' (if present) and will immediately return ErrNotFound if +// 'filter data' indicates that such pair doesn't exist. +// +// The caller may modify the contents of the returned slice as it is its +// own copy. +// It is safe to modify the contents of the argument after Find returns. +func (r *Reader) Find(key []byte, filtered bool, ro *opt.ReadOptions) (rkey, value []byte, err error) { + return r.find(key, filtered, ro, false) +} + +// Find finds key that is greater than or equal to the given key. +// It returns ErrNotFound if the table doesn't contain such key. +// If filtered is true then the nearest 'block' will be checked against +// 'filter data' (if present) and will immediately return ErrNotFound if +// 'filter data' indicates that such key doesn't exist. +// +// The caller may modify the contents of the returned slice as it is its +// own copy. +// It is safe to modify the contents of the argument after Find returns. +func (r *Reader) FindKey(key []byte, filtered bool, ro *opt.ReadOptions) (rkey []byte, err error) { + rkey, _, err = r.find(key, filtered, ro, true) + return +} + +// Get gets the value for the given key. It returns errors.ErrNotFound +// if the table does not contain the key. 
+// +// The caller may modify the contents of the returned slice as it is its +// own copy. +// It is safe to modify the contents of the argument after Find returns. +func (r *Reader) Get(key []byte, ro *opt.ReadOptions) (value []byte, err error) { + r.mu.RLock() + defer r.mu.RUnlock() + + if r.err != nil { + err = r.err + return + } + + rkey, value, err := r.find(key, false, ro, false) + if err == nil && r.cmp.Compare(rkey, key) != 0 { + value = nil + err = ErrNotFound + } + return +} + +// OffsetOf returns approximate offset for the given key. +// +// It is safe to modify the contents of the argument after Get returns. +func (r *Reader) OffsetOf(key []byte) (offset int64, err error) { + r.mu.RLock() + defer r.mu.RUnlock() + + if r.err != nil { + err = r.err + return + } + + indexBlock, rel, err := r.readBlockCached(r.indexBH, true, true) + if err != nil { + return + } + defer rel.Release() + + index := r.newBlockIter(indexBlock, nil, nil, true) + defer index.Release() + if index.Seek(key) { + dataBH, n := decodeBlockHandle(index.Value()) + if n == 0 { + r.err = r.newErrCorruptedBH(r.indexBH, "bad data block handle") + return + } + offset = int64(dataBH.offset) + return + } + err = index.Error() + if err == nil { + offset = r.dataEnd + } + return +} + +// Release implements util.Releaser. +// It also close the file if it is an io.Closer. +func (r *Reader) Release() { + r.mu.Lock() + defer r.mu.Unlock() + + if closer, ok := r.reader.(io.Closer); ok { + closer.Close() + } + if r.indexBlock != nil { + r.indexBlock.Release() + r.indexBlock = nil + } + if r.filterBlock != nil { + r.filterBlock.Release() + r.filterBlock = nil + } + r.reader = nil + r.cache = nil + r.bpool = nil + r.err = ErrReaderReleased +} + +// NewReader creates a new initialized table reader for the file. +// The fi, cache and bpool is optional and can be nil. +// +// The returned table reader instance is goroutine-safe. +func NewReader(f io.ReaderAt, size int64, fi *storage.FileInfo, cache *cache.CacheGetter, bpool *util.BufferPool, o *opt.Options) (*Reader, error) { + if f == nil { + return nil, errors.New("leveldb/table: nil file") + } + + r := &Reader{ + fi: fi, + reader: f, + cache: cache, + bpool: bpool, + o: o, + cmp: o.GetComparer(), + verifyChecksum: o.GetStrict(opt.StrictBlockChecksum), + } + + if size < footerLen { + r.err = r.newErrCorrupted(0, size, "table", "too small") + return r, nil + } + + footerPos := size - footerLen + var footer [footerLen]byte + if _, err := r.reader.ReadAt(footer[:], footerPos); err != nil && err != io.EOF { + return nil, err + } + if string(footer[footerLen-len(magic):footerLen]) != magic { + r.err = r.newErrCorrupted(footerPos, footerLen, "table-footer", "bad magic number") + return r, nil + } + + var n int + // Decode the metaindex block handle. + r.metaBH, n = decodeBlockHandle(footer[:]) + if n == 0 { + r.err = r.newErrCorrupted(footerPos, footerLen, "table-footer", "bad metaindex block handle") + return r, nil + } + + // Decode the index block handle. + r.indexBH, n = decodeBlockHandle(footer[n:]) + if n == 0 { + r.err = r.newErrCorrupted(footerPos, footerLen, "table-footer", "bad index block handle") + return r, nil + } + + // Read metaindex block. + metaBlock, err := r.readBlock(r.metaBH, true) + if err != nil { + if errors.IsCorrupted(err) { + r.err = err + return r, nil + } else { + return nil, err + } + } + + // Set data end. + r.dataEnd = int64(r.metaBH.offset) + + // Read metaindex. 
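+	// Each metaindex entry maps a meta block name to its block handle; the
+	// only names recognized here have the form "filter.<filter name>".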
+	metaIter := r.newBlockIter(metaBlock, nil, nil, true)
+	for metaIter.Next() {
+		key := string(metaIter.Key())
+		if !strings.HasPrefix(key, "filter.") {
+			continue
+		}
+		fn := key[7:]
+		if f0 := o.GetFilter(); f0 != nil && f0.Name() == fn {
+			r.filter = f0
+		} else {
+			for _, f0 := range o.GetAltFilters() {
+				if f0.Name() == fn {
+					r.filter = f0
+					break
+				}
+			}
+		}
+		if r.filter != nil {
+			filterBH, n := decodeBlockHandle(metaIter.Value())
+			if n == 0 {
+				continue
+			}
+			r.filterBH = filterBH
+			// Update data end.
+			r.dataEnd = int64(filterBH.offset)
+			break
+		}
+	}
+	metaIter.Release()
+	metaBlock.Release()
+
+	// Cache index and filter block locally, since we don't have global cache.
+	if cache == nil {
+		r.indexBlock, err = r.readBlock(r.indexBH, true)
+		if err != nil {
+			if errors.IsCorrupted(err) {
+				r.err = err
+				return r, nil
+			} else {
+				return nil, err
+			}
+		}
+		if r.filter != nil {
+			r.filterBlock, err = r.readFilterBlock(r.filterBH)
+			if err != nil {
+				if !errors.IsCorrupted(err) {
+					return nil, err
+				}
+
+				// Don't use filter then.
+				r.filter = nil
+			}
+		}
+	}
+
+	return r, nil
+}
diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/table/table.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/table/table.go
new file mode 100644
index 0000000000..beacdc1f02
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/table/table.go
@@ -0,0 +1,177 @@
+// Copyright (c) 2012, Suryandaru Triandana
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+// Package table allows reading and writing of sorted key/value pairs.
+package table
+
+import (
+	"encoding/binary"
+)
+
+/*
+Table:
+
+A table consists of one or more data blocks, an optional filter block,
+a metaindex block, an index block and a table footer. The metaindex
+block is a special block used to keep parameters of the table, such as
+the filter block name and its block handle. The index block is a special
+block used to keep records of data block offsets and lengths; the index
+block uses one as its restart interval. The keys used by the index block
+are the last key of the preceding block, a shorter separator between
+adjacent blocks, or a shorter successor of the last key of the last
+block. The filter block is an optional block containing a sequence of
+filter data generated by a filter generator.
+
+Table data structure:
+
+                                                      optional
+                                                     /
+    +--------------+--------------+--------------+--------------+-----------------+-------------+--------+
+    | data block 1 |      ...     | data block n | filter block | metaindex block | index block | footer |
+    +--------------+--------------+--------------+--------------+-----------------+-------------+--------+
+
+    Each block is followed by a 5-byte trailer containing the compression
+    type and a checksum.
+
+Table block trailer:
+
+    +---------------------------+-------------------+
+    | compression type (1-byte) | checksum (4-byte) |
+    +---------------------------+-------------------+
+
+    The checksum is a CRC-32 computed using Castagnoli's polynomial. The
+    compression type is also included in the checksum.
+
+Table footer:
+
+      +------------------- 40-bytes -------------------+
+     /                                                   \
+    +------------------------+--------------------+------+-----------------+
+    | metaindex block handle / index block handle / ---- | magic (8-bytes) |
+    +------------------------+--------------------+------+-----------------+
+
+    The magic is the first 64 bits of the SHA-1 sum of
+    "http://code.google.com/p/leveldb/".
+
+NOTE: All fixed-length integers are little-endian.
+*/
+
+/*
+Block:
+
+A block consists of one or more key/value entries and a block trailer.
+A block entry shares its key prefix with the preceding key until a
+restart point is reached. A block must contain at least one restart
+point; the first restart point is always zero.
+
+Block data structure:
+
+      + restart point                 + restart point (depends on restart interval)
+     /                               /
+    +---------------+---------------+---------------+---------------+---------+
+    | block entry 1 | block entry 2 |      ...      | block entry n | trailer |
+    +---------------+---------------+---------------+---------------+---------+
+
+Key/value entry:
+
+              +---- key len ----+
+             /                   \
+    +-----------------+---------------------+--------------------+--------------+----------------+
+    | shared (varint) | not shared (varint) | value len (varint) | key (varlen) | value (varlen) |
+    +-----------------+---------------------+--------------------+--------------+----------------+
+
+    A block entry shares its key prefix with the preceding key.
+    Conditions:
+        restart_interval=2
+        entry one  : key=deck,value=v1
+        entry two  : key=dock,value=v2
+        entry three: key=duck,value=v3
+    The entries will be encoded as follows:
+
+      + restart point (offset=0)
+     /                                                      + restart point (offset=16)
+    /                                                      /
+    +-----+-----+-----+----------+--------+-----+-----+-----+---------+--------+-----+-----+-----+----------+--------+
+    |  0  |  4  |  2  |  "deck"  |  "v1"  |  1  |  3  |  2  |  "ock"  |  "v2"  |  0  |  4  |  2  |  "duck"  |  "v3"  |
+    +-----+-----+-----+----------+--------+-----+-----+-----+---------+--------+-----+-----+-----+----------+--------+
+     \                                   / \                                  / \                                   /
+      +----------- entry one -----------+   +----------- entry two ----------+   +---------- entry three ----------+
+
+    The block trailer will contain two restart points:
+
+    +------------+-----------+--------+
+    |      0     |     16    |    2   |
+    +------------+-----------+---+----+
+     \                      /     \
+      +-- restart points --+       + restart points length
+
+Block trailer:
+
+      +-- 4-bytes --+
+     /               \
+    +-----------------+-----------------+-----------------+------------------------------+
+    | restart point 1 |       ....      | restart point n | restart points len (4-bytes) |
+    +-----------------+-----------------+-----------------+------------------------------+
+
+NOTE: All fixed-length integers are little-endian.
+*/
+
+/*
+Filter block:
+
+A filter block consists of one or more pieces of filter data and a
+filter block trailer. The trailer contains the filter data offsets, a
+trailer offset and a 1-byte base Lg.
+
+Filter block data structure:
+
+      + offset 1      + offset 2      + offset n      + trailer offset
+     /               /               /               /
+    +---------------+---------------+---------------+---------+
+    | filter data 1 |      ...      | filter data n | trailer |
+    +---------------+---------------+---------------+---------+
+
+Filter block trailer:
+
+      +- 4-bytes -+
+     /             \
+    +---------------+---------------+---------------+-------------------------------+------------------+
+    | data 1 offset |      ....     | data n offset | data-offsets offset (4-bytes) | base Lg (1-byte) |
+    +---------------+---------------+---------------+-------------------------------+------------------+
+
+NOTE: All fixed-length integers are little-endian.
+*/
+
+const (
+	blockTrailerLen = 5
+	footerLen       = 48
+
+	magic = "\x57\xfb\x80\x8b\x24\x75\x47\xdb"
+
+	// The block type gives the per-block compression format.
+	// These constants are part of the file format and should not be changed.
+	blockTypeNoCompression     = 0
+	blockTypeSnappyCompression = 1
+
+	// Generate new filter every 2KB of data
+	filterBaseLg = 11
+	filterBase   = 1 << filterBaseLg
+)
+
+type blockHandle struct {
+	offset, length uint64
+}
+
+func decodeBlockHandle(src []byte) (blockHandle, int) {
+	offset, n := binary.Uvarint(src)
+	length, m := binary.Uvarint(src[n:])
+	if n == 0 || m == 0 {
+		return blockHandle{}, 0
+	}
+	return blockHandle{offset, length}, n + m
+}
+
+func encodeBlockHandle(dst []byte, b blockHandle) int {
+	n := binary.PutUvarint(dst, b.offset)
+	m := binary.PutUvarint(dst[n:], b.length)
+	return n + m
+}
diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/table/table_suite_test.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/table/table_suite_test.go
new file mode 100644
index 0000000000..6465da6e37
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/table/table_suite_test.go
@@ -0,0 +1,11 @@
+package table
+
+import (
+	"testing"
+
+	"github.com/syndtr/goleveldb/leveldb/testutil"
+)
+
+func TestTable(t *testing.T) {
+	testutil.RunSuite(t, "Table Suite")
+}
diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/table/table_test.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/table/table_test.go
new file mode 100644
index 0000000000..4b59b31f52
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/table/table_test.go
@@ -0,0 +1,122 @@
+// Copyright (c) 2014, Suryandaru Triandana
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+package table
+
+import (
+	"bytes"
+
+	. "github.com/onsi/ginkgo"
+	. "github.com/onsi/gomega"
+
+	"github.com/syndtr/goleveldb/leveldb/iterator"
+	"github.com/syndtr/goleveldb/leveldb/opt"
+	"github.com/syndtr/goleveldb/leveldb/testutil"
+	"github.com/syndtr/goleveldb/leveldb/util"
+)
+
+type tableWrapper struct {
+	*Reader
+}
+
+func (t tableWrapper) TestFind(key []byte) (rkey, rvalue []byte, err error) {
+	return t.Reader.Find(key, false, nil)
+}
+
+func (t tableWrapper) TestGet(key []byte) (value []byte, err error) {
+	return t.Reader.Get(key, nil)
+}
+
+func (t tableWrapper) TestNewIterator(slice *util.Range) iterator.Iterator {
+	return t.Reader.NewIterator(slice, nil)
+}
+
+var _ = testutil.Defer(func() {
+	Describe("Table", func() {
+		Describe("approximate offset test", func() {
+			var (
+				buf = &bytes.Buffer{}
+				o   = &opt.Options{
+					BlockSize:   1024,
+					Compression: opt.NoCompression,
+				}
+			)
+
+			// Building the table.
+ tw := NewWriter(buf, o) + tw.Append([]byte("k01"), []byte("hello")) + tw.Append([]byte("k02"), []byte("hello2")) + tw.Append([]byte("k03"), bytes.Repeat([]byte{'x'}, 10000)) + tw.Append([]byte("k04"), bytes.Repeat([]byte{'x'}, 200000)) + tw.Append([]byte("k05"), bytes.Repeat([]byte{'x'}, 300000)) + tw.Append([]byte("k06"), []byte("hello3")) + tw.Append([]byte("k07"), bytes.Repeat([]byte{'x'}, 100000)) + err := tw.Close() + + It("Should be able to approximate offset of a key correctly", func() { + Expect(err).ShouldNot(HaveOccurred()) + + tr, err := NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()), nil, nil, nil, o) + Expect(err).ShouldNot(HaveOccurred()) + CheckOffset := func(key string, expect, threshold int) { + offset, err := tr.OffsetOf([]byte(key)) + Expect(err).ShouldNot(HaveOccurred()) + Expect(offset).Should(BeNumerically("~", expect, threshold), "Offset of key %q", key) + } + + CheckOffset("k0", 0, 0) + CheckOffset("k01a", 0, 0) + CheckOffset("k02", 0, 0) + CheckOffset("k03", 0, 0) + CheckOffset("k04", 10000, 1000) + CheckOffset("k04a", 210000, 1000) + CheckOffset("k05", 210000, 1000) + CheckOffset("k06", 510000, 1000) + CheckOffset("k07", 510000, 1000) + CheckOffset("xyz", 610000, 2000) + }) + }) + + Describe("read test", func() { + Build := func(kv testutil.KeyValue) testutil.DB { + o := &opt.Options{ + BlockSize: 512, + BlockRestartInterval: 3, + } + buf := &bytes.Buffer{} + + // Building the table. + tw := NewWriter(buf, o) + kv.Iterate(func(i int, key, value []byte) { + tw.Append(key, value) + }) + tw.Close() + + // Opening the table. + tr, _ := NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()), nil, nil, nil, o) + return tableWrapper{tr} + } + Test := func(kv *testutil.KeyValue, body func(r *Reader)) func() { + return func() { + db := Build(*kv) + if body != nil { + body(db.(tableWrapper).Reader) + } + testutil.KeyValueTesting(nil, *kv, db, nil, nil) + } + } + + testutil.AllKeyValueTesting(nil, Build, nil, nil) + Describe("with one key per block", Test(testutil.KeyValue_Generate(nil, 9, 1, 10, 512, 512), func(r *Reader) { + It("should have correct blocks number", func() { + indexBlock, err := r.readBlock(r.indexBH, true) + Expect(err).To(BeNil()) + Expect(indexBlock.restartsLen).Should(Equal(9)) + }) + })) + }) + }) +}) diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/table/writer.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/table/writer.go new file mode 100644 index 0000000000..274c95fade --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/table/writer.go @@ -0,0 +1,379 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
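+
+// A minimal usage sketch for the Writer defined in this file (assuming
+// o is a *opt.Options; error handling elided, and keys must be appended
+// in increasing order):
+//
+//	var buf bytes.Buffer
+//	tw := table.NewWriter(&buf, o)
+//	tw.Append([]byte("k1"), []byte("v1"))
+//	tw.Append([]byte("k2"), []byte("v2"))
+//	tw.Close()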
+ +package table + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + + "github.com/syndtr/gosnappy/snappy" + + "github.com/syndtr/goleveldb/leveldb/comparer" + "github.com/syndtr/goleveldb/leveldb/filter" + "github.com/syndtr/goleveldb/leveldb/opt" + "github.com/syndtr/goleveldb/leveldb/util" +) + +func sharedPrefixLen(a, b []byte) int { + i, n := 0, len(a) + if n > len(b) { + n = len(b) + } + for i < n && a[i] == b[i] { + i++ + } + return i +} + +type blockWriter struct { + restartInterval int + buf util.Buffer + nEntries int + prevKey []byte + restarts []uint32 + scratch []byte +} + +func (w *blockWriter) append(key, value []byte) { + nShared := 0 + if w.nEntries%w.restartInterval == 0 { + w.restarts = append(w.restarts, uint32(w.buf.Len())) + } else { + nShared = sharedPrefixLen(w.prevKey, key) + } + n := binary.PutUvarint(w.scratch[0:], uint64(nShared)) + n += binary.PutUvarint(w.scratch[n:], uint64(len(key)-nShared)) + n += binary.PutUvarint(w.scratch[n:], uint64(len(value))) + w.buf.Write(w.scratch[:n]) + w.buf.Write(key[nShared:]) + w.buf.Write(value) + w.prevKey = append(w.prevKey[:0], key...) + w.nEntries++ +} + +func (w *blockWriter) finish() { + // Write restarts entry. + if w.nEntries == 0 { + // Must have at least one restart entry. + w.restarts = append(w.restarts, 0) + } + w.restarts = append(w.restarts, uint32(len(w.restarts))) + for _, x := range w.restarts { + buf4 := w.buf.Alloc(4) + binary.LittleEndian.PutUint32(buf4, x) + } +} + +func (w *blockWriter) reset() { + w.buf.Reset() + w.nEntries = 0 + w.restarts = w.restarts[:0] +} + +func (w *blockWriter) bytesLen() int { + restartsLen := len(w.restarts) + if restartsLen == 0 { + restartsLen = 1 + } + return w.buf.Len() + 4*restartsLen + 4 +} + +type filterWriter struct { + generator filter.FilterGenerator + buf util.Buffer + nKeys int + offsets []uint32 +} + +func (w *filterWriter) add(key []byte) { + if w.generator == nil { + return + } + w.generator.Add(key) + w.nKeys++ +} + +func (w *filterWriter) flush(offset uint64) { + if w.generator == nil { + return + } + for x := int(offset / filterBase); x > len(w.offsets); { + w.generate() + } +} + +func (w *filterWriter) finish() { + if w.generator == nil { + return + } + // Generate last keys. + + if w.nKeys > 0 { + w.generate() + } + w.offsets = append(w.offsets, uint32(w.buf.Len())) + for _, x := range w.offsets { + buf4 := w.buf.Alloc(4) + binary.LittleEndian.PutUint32(buf4, x) + } + w.buf.WriteByte(filterBaseLg) +} + +func (w *filterWriter) generate() { + // Record offset. + w.offsets = append(w.offsets, uint32(w.buf.Len())) + // Generate filters. + if w.nKeys > 0 { + w.generator.Generate(&w.buf) + w.nKeys = 0 + } +} + +// Writer is a table writer. +type Writer struct { + writer io.Writer + err error + // Options + cmp comparer.Comparer + filter filter.Filter + compression opt.Compression + blockSize int + + dataBlock blockWriter + indexBlock blockWriter + filterBlock filterWriter + pendingBH blockHandle + offset uint64 + nEntries int + // Scratch allocated enough for 5 uvarint. Block writer should not use + // first 20-bytes since it will be used to encode block handle, which + // then passed to the block writer itself. + scratch [50]byte + comparerScratch []byte + compressionScratch []byte +} + +func (w *Writer) writeBlock(buf *util.Buffer, compression opt.Compression) (bh blockHandle, err error) { + // Compress the buffer if necessary. + var b []byte + if compression == opt.SnappyCompression { + // Allocate scratch enough for compression and block trailer. 
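+		// snappy.MaxEncodedLen is the worst-case compressed size; the extra
+		// blockTrailerLen bytes leave room for the 5-byte trailer (compression
+		// type and checksum) filled in below.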
+ if n := snappy.MaxEncodedLen(buf.Len()) + blockTrailerLen; len(w.compressionScratch) < n { + w.compressionScratch = make([]byte, n) + } + var compressed []byte + compressed, err = snappy.Encode(w.compressionScratch, buf.Bytes()) + if err != nil { + return + } + n := len(compressed) + b = compressed[:n+blockTrailerLen] + b[n] = blockTypeSnappyCompression + } else { + tmp := buf.Alloc(blockTrailerLen) + tmp[0] = blockTypeNoCompression + b = buf.Bytes() + } + + // Calculate the checksum. + n := len(b) - 4 + checksum := util.NewCRC(b[:n]).Value() + binary.LittleEndian.PutUint32(b[n:], checksum) + + // Write the buffer to the file. + _, err = w.writer.Write(b) + if err != nil { + return + } + bh = blockHandle{w.offset, uint64(len(b) - blockTrailerLen)} + w.offset += uint64(len(b)) + return +} + +func (w *Writer) flushPendingBH(key []byte) { + if w.pendingBH.length == 0 { + return + } + var separator []byte + if len(key) == 0 { + separator = w.cmp.Successor(w.comparerScratch[:0], w.dataBlock.prevKey) + } else { + separator = w.cmp.Separator(w.comparerScratch[:0], w.dataBlock.prevKey, key) + } + if separator == nil { + separator = w.dataBlock.prevKey + } else { + w.comparerScratch = separator + } + n := encodeBlockHandle(w.scratch[:20], w.pendingBH) + // Append the block handle to the index block. + w.indexBlock.append(separator, w.scratch[:n]) + // Reset prev key of the data block. + w.dataBlock.prevKey = w.dataBlock.prevKey[:0] + // Clear pending block handle. + w.pendingBH = blockHandle{} +} + +func (w *Writer) finishBlock() error { + w.dataBlock.finish() + bh, err := w.writeBlock(&w.dataBlock.buf, w.compression) + if err != nil { + return err + } + w.pendingBH = bh + // Reset the data block. + w.dataBlock.reset() + // Flush the filter block. + w.filterBlock.flush(w.offset) + return nil +} + +// Append appends key/value pair to the table. The keys passed must +// be in increasing order. +// +// It is safe to modify the contents of the arguments after Append returns. +func (w *Writer) Append(key, value []byte) error { + if w.err != nil { + return w.err + } + if w.nEntries > 0 && w.cmp.Compare(w.dataBlock.prevKey, key) >= 0 { + w.err = fmt.Errorf("leveldb/table: Writer: keys are not in increasing order: %q, %q", w.dataBlock.prevKey, key) + return w.err + } + + w.flushPendingBH(key) + // Append key/value pair to the data block. + w.dataBlock.append(key, value) + // Add key to the filter block. + w.filterBlock.add(key) + + // Finish the data block if block size target reached. + if w.dataBlock.bytesLen() >= w.blockSize { + if err := w.finishBlock(); err != nil { + w.err = err + return w.err + } + } + w.nEntries++ + return nil +} + +// BlocksLen returns number of blocks written so far. +func (w *Writer) BlocksLen() int { + n := w.indexBlock.nEntries + if w.pendingBH.length > 0 { + // Includes the pending block. + n++ + } + return n +} + +// EntriesLen returns number of entries added so far. +func (w *Writer) EntriesLen() int { + return w.nEntries +} + +// BytesLen returns number of bytes written so far. +func (w *Writer) BytesLen() int { + return int(w.offset) +} + +// Close will finalize the table. Calling Append is not possible +// after Close, but calling BlocksLen, EntriesLen and BytesLen +// is still possible. +func (w *Writer) Close() error { + if w.err != nil { + return w.err + } + + // Write the last data block. Or empty data block if there + // aren't any data blocks at all. 
+ if w.dataBlock.nEntries > 0 || w.nEntries == 0 { + if err := w.finishBlock(); err != nil { + w.err = err + return w.err + } + } + w.flushPendingBH(nil) + + // Write the filter block. + var filterBH blockHandle + w.filterBlock.finish() + if buf := &w.filterBlock.buf; buf.Len() > 0 { + filterBH, w.err = w.writeBlock(buf, opt.NoCompression) + if w.err != nil { + return w.err + } + } + + // Write the metaindex block. + if filterBH.length > 0 { + key := []byte("filter." + w.filter.Name()) + n := encodeBlockHandle(w.scratch[:20], filterBH) + w.dataBlock.append(key, w.scratch[:n]) + } + w.dataBlock.finish() + metaindexBH, err := w.writeBlock(&w.dataBlock.buf, w.compression) + if err != nil { + w.err = err + return w.err + } + + // Write the index block. + w.indexBlock.finish() + indexBH, err := w.writeBlock(&w.indexBlock.buf, w.compression) + if err != nil { + w.err = err + return w.err + } + + // Write the table footer. + footer := w.scratch[:footerLen] + for i := range footer { + footer[i] = 0 + } + n := encodeBlockHandle(footer, metaindexBH) + encodeBlockHandle(footer[n:], indexBH) + copy(footer[footerLen-len(magic):], magic) + if _, err := w.writer.Write(footer); err != nil { + w.err = err + return w.err + } + w.offset += footerLen + + w.err = errors.New("leveldb/table: writer is closed") + return nil +} + +// NewWriter creates a new initialized table writer for the file. +// +// Table writer is not goroutine-safe. +func NewWriter(f io.Writer, o *opt.Options) *Writer { + w := &Writer{ + writer: f, + cmp: o.GetComparer(), + filter: o.GetFilter(), + compression: o.GetCompression(), + blockSize: o.GetBlockSize(), + comparerScratch: make([]byte, 0), + } + // data block + w.dataBlock.restartInterval = o.GetBlockRestartInterval() + // The first 20-bytes are used for encoding block handle. + w.dataBlock.scratch = w.scratch[20:] + // index block + w.indexBlock.restartInterval = 1 + w.indexBlock.scratch = w.scratch[20:] + // filter block + if w.filter != nil { + w.filterBlock.generator = w.filter.NewGenerator() + w.filterBlock.flush(0) + } + return w +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil/db.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil/db.go new file mode 100644 index 0000000000..ec3f177a12 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil/db.go @@ -0,0 +1,222 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package testutil + +import ( + "fmt" + "math/rand" + + . 
"github.com/onsi/gomega" + + "github.com/syndtr/goleveldb/leveldb/errors" + "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/util" +) + +type DB interface{} + +type Put interface { + TestPut(key []byte, value []byte) error +} + +type Delete interface { + TestDelete(key []byte) error +} + +type Find interface { + TestFind(key []byte) (rkey, rvalue []byte, err error) +} + +type Get interface { + TestGet(key []byte) (value []byte, err error) +} + +type Has interface { + TestHas(key []byte) (ret bool, err error) +} + +type NewIterator interface { + TestNewIterator(slice *util.Range) iterator.Iterator +} + +type DBAct int + +func (a DBAct) String() string { + switch a { + case DBNone: + return "none" + case DBPut: + return "put" + case DBOverwrite: + return "overwrite" + case DBDelete: + return "delete" + case DBDeleteNA: + return "delete_na" + } + return "unknown" +} + +const ( + DBNone DBAct = iota + DBPut + DBOverwrite + DBDelete + DBDeleteNA +) + +type DBTesting struct { + Rand *rand.Rand + DB interface { + Get + Put + Delete + } + PostFn func(t *DBTesting) + Deleted, Present KeyValue + Act, LastAct DBAct + ActKey, LastActKey []byte +} + +func (t *DBTesting) post() { + if t.PostFn != nil { + t.PostFn(t) + } +} + +func (t *DBTesting) setAct(act DBAct, key []byte) { + t.LastAct, t.Act = t.Act, act + t.LastActKey, t.ActKey = t.ActKey, key +} + +func (t *DBTesting) text() string { + return fmt.Sprintf("last action was <%v> %q, <%v> %q", t.LastAct, t.LastActKey, t.Act, t.ActKey) +} + +func (t *DBTesting) Text() string { + return "DBTesting " + t.text() +} + +func (t *DBTesting) TestPresentKV(key, value []byte) { + rvalue, err := t.DB.TestGet(key) + Expect(err).ShouldNot(HaveOccurred(), "Get on key %q, %s", key, t.text()) + Expect(rvalue).Should(Equal(value), "Value for key %q, %s", key, t.text()) +} + +func (t *DBTesting) TestAllPresent() { + t.Present.IterateShuffled(t.Rand, func(i int, key, value []byte) { + t.TestPresentKV(key, value) + }) +} + +func (t *DBTesting) TestDeletedKey(key []byte) { + _, err := t.DB.TestGet(key) + Expect(err).Should(Equal(errors.ErrNotFound), "Get on deleted key %q, %s", key, t.text()) +} + +func (t *DBTesting) TestAllDeleted() { + t.Deleted.IterateShuffled(t.Rand, func(i int, key, value []byte) { + t.TestDeletedKey(key) + }) +} + +func (t *DBTesting) TestAll() { + dn := t.Deleted.Len() + pn := t.Present.Len() + ShuffledIndex(t.Rand, dn+pn, 1, func(i int) { + if i >= dn { + key, value := t.Present.Index(i - dn) + t.TestPresentKV(key, value) + } else { + t.TestDeletedKey(t.Deleted.KeyAt(i)) + } + }) +} + +func (t *DBTesting) Put(key, value []byte) { + if new := t.Present.PutU(key, value); new { + t.setAct(DBPut, key) + } else { + t.setAct(DBOverwrite, key) + } + t.Deleted.Delete(key) + err := t.DB.TestPut(key, value) + Expect(err).ShouldNot(HaveOccurred(), t.Text()) + t.TestPresentKV(key, value) + t.post() +} + +func (t *DBTesting) PutRandom() bool { + if t.Deleted.Len() > 0 { + i := t.Rand.Intn(t.Deleted.Len()) + key, value := t.Deleted.Index(i) + t.Put(key, value) + return true + } + return false +} + +func (t *DBTesting) Delete(key []byte) { + if exist, value := t.Present.Delete(key); exist { + t.setAct(DBDelete, key) + t.Deleted.PutU(key, value) + } else { + t.setAct(DBDeleteNA, key) + } + err := t.DB.TestDelete(key) + Expect(err).ShouldNot(HaveOccurred(), t.Text()) + t.TestDeletedKey(key) + t.post() +} + +func (t *DBTesting) DeleteRandom() bool { + if t.Present.Len() > 0 { + i := t.Rand.Intn(t.Present.Len()) + 
t.Delete(t.Present.KeyAt(i)) + return true + } + return false +} + +func (t *DBTesting) RandomAct(round int) { + for i := 0; i < round; i++ { + if t.Rand.Int()%2 == 0 { + t.PutRandom() + } else { + t.DeleteRandom() + } + } +} + +func DoDBTesting(t *DBTesting) { + if t.Rand == nil { + t.Rand = NewRand() + } + + t.DeleteRandom() + t.PutRandom() + t.DeleteRandom() + t.DeleteRandom() + for i := t.Deleted.Len() / 2; i >= 0; i-- { + t.PutRandom() + } + t.RandomAct((t.Deleted.Len() + t.Present.Len()) * 10) + + // Additional iterator testing + if db, ok := t.DB.(NewIterator); ok { + iter := db.TestNewIterator(nil) + Expect(iter.Error()).NotTo(HaveOccurred()) + + it := IteratorTesting{ + KeyValue: t.Present, + Iter: iter, + } + + DoIteratorTesting(&it) + iter.Release() + } +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil/ginkgo.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil/ginkgo.go new file mode 100644 index 0000000000..82f3d0e811 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil/ginkgo.go @@ -0,0 +1,21 @@ +package testutil + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func RunSuite(t GinkgoTestingT, name string) { + RunDefer() + + SynchronizedBeforeSuite(func() []byte { + RunDefer("setup") + return nil + }, func(data []byte) {}) + SynchronizedAfterSuite(func() { + RunDefer("teardown") + }, func() {}) + + RegisterFailHandler(Fail) + RunSpecs(t, name) +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil/iter.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil/iter.go new file mode 100644 index 0000000000..df6d9db6a3 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil/iter.go @@ -0,0 +1,327 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package testutil + +import ( + "fmt" + "math/rand" + + . 
"github.com/onsi/gomega" + + "github.com/syndtr/goleveldb/leveldb/iterator" +) + +type IterAct int + +func (a IterAct) String() string { + switch a { + case IterNone: + return "none" + case IterFirst: + return "first" + case IterLast: + return "last" + case IterPrev: + return "prev" + case IterNext: + return "next" + case IterSeek: + return "seek" + case IterSOI: + return "soi" + case IterEOI: + return "eoi" + } + return "unknown" +} + +const ( + IterNone IterAct = iota + IterFirst + IterLast + IterPrev + IterNext + IterSeek + IterSOI + IterEOI +) + +type IteratorTesting struct { + KeyValue + Iter iterator.Iterator + Rand *rand.Rand + PostFn func(t *IteratorTesting) + Pos int + Act, LastAct IterAct + + once bool +} + +func (t *IteratorTesting) init() { + if !t.once { + t.Pos = -1 + t.once = true + } +} + +func (t *IteratorTesting) post() { + if t.PostFn != nil { + t.PostFn(t) + } +} + +func (t *IteratorTesting) setAct(act IterAct) { + t.LastAct, t.Act = t.Act, act +} + +func (t *IteratorTesting) text() string { + return fmt.Sprintf("at pos %d and last action was <%v> -> <%v>", t.Pos, t.LastAct, t.Act) +} + +func (t *IteratorTesting) Text() string { + return "IteratorTesting is " + t.text() +} + +func (t *IteratorTesting) IsFirst() bool { + t.init() + return t.Len() > 0 && t.Pos == 0 +} + +func (t *IteratorTesting) IsLast() bool { + t.init() + return t.Len() > 0 && t.Pos == t.Len()-1 +} + +func (t *IteratorTesting) TestKV() { + t.init() + key, value := t.Index(t.Pos) + Expect(t.Iter.Key()).NotTo(BeNil()) + Expect(t.Iter.Key()).Should(Equal(key), "Key is invalid, %s", t.text()) + Expect(t.Iter.Value()).Should(Equal(value), "Value for key %q, %s", key, t.text()) +} + +func (t *IteratorTesting) First() { + t.init() + t.setAct(IterFirst) + + ok := t.Iter.First() + Expect(t.Iter.Error()).ShouldNot(HaveOccurred()) + if t.Len() > 0 { + t.Pos = 0 + Expect(ok).Should(BeTrue(), t.Text()) + t.TestKV() + } else { + t.Pos = -1 + Expect(ok).ShouldNot(BeTrue(), t.Text()) + } + t.post() +} + +func (t *IteratorTesting) Last() { + t.init() + t.setAct(IterLast) + + ok := t.Iter.Last() + Expect(t.Iter.Error()).ShouldNot(HaveOccurred()) + if t.Len() > 0 { + t.Pos = t.Len() - 1 + Expect(ok).Should(BeTrue(), t.Text()) + t.TestKV() + } else { + t.Pos = 0 + Expect(ok).ShouldNot(BeTrue(), t.Text()) + } + t.post() +} + +func (t *IteratorTesting) Next() { + t.init() + t.setAct(IterNext) + + ok := t.Iter.Next() + Expect(t.Iter.Error()).ShouldNot(HaveOccurred()) + if t.Pos < t.Len()-1 { + t.Pos++ + Expect(ok).Should(BeTrue(), t.Text()) + t.TestKV() + } else { + t.Pos = t.Len() + Expect(ok).ShouldNot(BeTrue(), t.Text()) + } + t.post() +} + +func (t *IteratorTesting) Prev() { + t.init() + t.setAct(IterPrev) + + ok := t.Iter.Prev() + Expect(t.Iter.Error()).ShouldNot(HaveOccurred()) + if t.Pos > 0 { + t.Pos-- + Expect(ok).Should(BeTrue(), t.Text()) + t.TestKV() + } else { + t.Pos = -1 + Expect(ok).ShouldNot(BeTrue(), t.Text()) + } + t.post() +} + +func (t *IteratorTesting) Seek(i int) { + t.init() + t.setAct(IterSeek) + + key, _ := t.Index(i) + oldKey, _ := t.IndexOrNil(t.Pos) + + ok := t.Iter.Seek(key) + Expect(t.Iter.Error()).ShouldNot(HaveOccurred()) + Expect(ok).Should(BeTrue(), fmt.Sprintf("Seek from key %q to %q, to pos %d, %s", oldKey, key, i, t.text())) + + t.Pos = i + t.TestKV() + t.post() +} + +func (t *IteratorTesting) SeekInexact(i int) { + t.init() + t.setAct(IterSeek) + var key0 []byte + key1, _ := t.Index(i) + if i > 0 { + key0, _ = t.Index(i - 1) + } + key := BytesSeparator(key0, key1) + oldKey, _ := 
t.IndexOrNil(t.Pos) + + ok := t.Iter.Seek(key) + Expect(t.Iter.Error()).ShouldNot(HaveOccurred()) + Expect(ok).Should(BeTrue(), fmt.Sprintf("Seek from key %q to %q (%q), to pos %d, %s", oldKey, key, key1, i, t.text())) + + t.Pos = i + t.TestKV() + t.post() +} + +func (t *IteratorTesting) SeekKey(key []byte) { + t.init() + t.setAct(IterSeek) + oldKey, _ := t.IndexOrNil(t.Pos) + i := t.Search(key) + + ok := t.Iter.Seek(key) + Expect(t.Iter.Error()).ShouldNot(HaveOccurred()) + if i < t.Len() { + key_, _ := t.Index(i) + Expect(ok).Should(BeTrue(), fmt.Sprintf("Seek from key %q to %q (%q), to pos %d, %s", oldKey, key, key_, i, t.text())) + t.Pos = i + t.TestKV() + } else { + Expect(ok).ShouldNot(BeTrue(), fmt.Sprintf("Seek from key %q to %q, %s", oldKey, key, t.text())) + } + + t.Pos = i + t.post() +} + +func (t *IteratorTesting) SOI() { + t.init() + t.setAct(IterSOI) + Expect(t.Pos).Should(BeNumerically("<=", 0), t.Text()) + for i := 0; i < 3; i++ { + t.Prev() + } + t.post() +} + +func (t *IteratorTesting) EOI() { + t.init() + t.setAct(IterEOI) + Expect(t.Pos).Should(BeNumerically(">=", t.Len()-1), t.Text()) + for i := 0; i < 3; i++ { + t.Next() + } + t.post() +} + +func (t *IteratorTesting) WalkPrev(fn func(t *IteratorTesting)) { + t.init() + for old := t.Pos; t.Pos > 0; old = t.Pos { + fn(t) + Expect(t.Pos).Should(BeNumerically("<", old), t.Text()) + } +} + +func (t *IteratorTesting) WalkNext(fn func(t *IteratorTesting)) { + t.init() + for old := t.Pos; t.Pos < t.Len()-1; old = t.Pos { + fn(t) + Expect(t.Pos).Should(BeNumerically(">", old), t.Text()) + } +} + +func (t *IteratorTesting) PrevAll() { + t.WalkPrev(func(t *IteratorTesting) { + t.Prev() + }) +} + +func (t *IteratorTesting) NextAll() { + t.WalkNext(func(t *IteratorTesting) { + t.Next() + }) +} + +func DoIteratorTesting(t *IteratorTesting) { + if t.Rand == nil { + t.Rand = NewRand() + } + t.SOI() + t.NextAll() + t.First() + t.SOI() + t.NextAll() + t.EOI() + t.PrevAll() + t.Last() + t.EOI() + t.PrevAll() + t.SOI() + + t.NextAll() + t.PrevAll() + t.NextAll() + t.Last() + t.PrevAll() + t.First() + t.NextAll() + t.EOI() + + ShuffledIndex(t.Rand, t.Len(), 1, func(i int) { + t.Seek(i) + }) + + ShuffledIndex(t.Rand, t.Len(), 1, func(i int) { + t.SeekInexact(i) + }) + + ShuffledIndex(t.Rand, t.Len(), 1, func(i int) { + t.Seek(i) + if i%2 != 0 { + t.PrevAll() + t.SOI() + } else { + t.NextAll() + t.EOI() + } + }) + + for _, key := range []string{"", "foo", "bar", "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"} { + t.SeekKey([]byte(key)) + } +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil/kv.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil/kv.go new file mode 100644 index 0000000000..471d5708c3 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil/kv.go @@ -0,0 +1,352 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
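+// KeyValue below is a sorted in-memory key/value list; the other
+// testutil helpers use it as the reference model that a DB
+// implementation is checked against. A minimal usage sketch
+// (editorial, hypothetical values, not part of the vendored file):
+//
+//	kv := &KeyValue{}
+//	kv.PutString("a", "1") // keys must be added in increasing order
+//	i, ok := kv.GetString("a") // ok == true
+//	_, value := kv.Index(i)    // value == []byte("1")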
+ +package testutil + +import ( + "fmt" + "math/rand" + "sort" + "strings" + + "github.com/syndtr/goleveldb/leveldb/util" +) + +type KeyValueEntry struct { + key, value []byte +} + +type KeyValue struct { + entries []KeyValueEntry + nbytes int +} + +func (kv *KeyValue) Put(key, value []byte) { + if n := len(kv.entries); n > 0 && cmp.Compare(kv.entries[n-1].key, key) >= 0 { + panic(fmt.Sprintf("Put: keys are not in increasing order: %q, %q", kv.entries[n-1].key, key)) + } + kv.entries = append(kv.entries, KeyValueEntry{key, value}) + kv.nbytes += len(key) + len(value) +} + +func (kv *KeyValue) PutString(key, value string) { + kv.Put([]byte(key), []byte(value)) +} + +func (kv *KeyValue) PutU(key, value []byte) bool { + if i, exist := kv.Get(key); !exist { + if i < kv.Len() { + kv.entries = append(kv.entries[:i+1], kv.entries[i:]...) + kv.entries[i] = KeyValueEntry{key, value} + } else { + kv.entries = append(kv.entries, KeyValueEntry{key, value}) + } + kv.nbytes += len(key) + len(value) + return true + } else { + kv.nbytes += len(value) - len(kv.ValueAt(i)) + kv.entries[i].value = value + } + return false +} + +func (kv *KeyValue) PutUString(key, value string) bool { + return kv.PutU([]byte(key), []byte(value)) +} + +func (kv *KeyValue) Delete(key []byte) (exist bool, value []byte) { + i, exist := kv.Get(key) + if exist { + value = kv.entries[i].value + kv.DeleteIndex(i) + } + return +} + +func (kv *KeyValue) DeleteIndex(i int) bool { + if i < kv.Len() { + kv.nbytes -= len(kv.KeyAt(i)) + len(kv.ValueAt(i)) + kv.entries = append(kv.entries[:i], kv.entries[i+1:]...) + return true + } + return false +} + +func (kv KeyValue) Len() int { + return len(kv.entries) +} + +func (kv *KeyValue) Size() int { + return kv.nbytes +} + +func (kv KeyValue) KeyAt(i int) []byte { + return kv.entries[i].key +} + +func (kv KeyValue) ValueAt(i int) []byte { + return kv.entries[i].value +} + +func (kv KeyValue) Index(i int) (key, value []byte) { + if i < 0 || i >= len(kv.entries) { + panic(fmt.Sprintf("Index #%d: out of range", i)) + } + return kv.entries[i].key, kv.entries[i].value +} + +func (kv KeyValue) IndexInexact(i int) (key_, key, value []byte) { + key, value = kv.Index(i) + var key0 []byte + var key1 = kv.KeyAt(i) + if i > 0 { + key0 = kv.KeyAt(i - 1) + } + key_ = BytesSeparator(key0, key1) + return +} + +func (kv KeyValue) IndexOrNil(i int) (key, value []byte) { + if i >= 0 && i < len(kv.entries) { + return kv.entries[i].key, kv.entries[i].value + } + return nil, nil +} + +func (kv KeyValue) IndexString(i int) (key, value string) { + key_, _value := kv.Index(i) + return string(key_), string(_value) +} + +func (kv KeyValue) Search(key []byte) int { + return sort.Search(kv.Len(), func(i int) bool { + return cmp.Compare(kv.KeyAt(i), key) >= 0 + }) +} + +func (kv KeyValue) SearchString(key string) int { + return kv.Search([]byte(key)) +} + +func (kv KeyValue) Get(key []byte) (i int, exist bool) { + i = kv.Search(key) + if i < kv.Len() && cmp.Compare(kv.KeyAt(i), key) == 0 { + exist = true + } + return +} + +func (kv KeyValue) GetString(key string) (i int, exist bool) { + return kv.Get([]byte(key)) +} + +func (kv KeyValue) Iterate(fn func(i int, key, value []byte)) { + for i, x := range kv.entries { + fn(i, x.key, x.value) + } +} + +func (kv KeyValue) IterateString(fn func(i int, key, value string)) { + kv.Iterate(func(i int, key, value []byte) { + fn(i, string(key), string(value)) + }) +} + +func (kv KeyValue) IterateShuffled(rnd *rand.Rand, fn func(i int, key, value []byte)) { + ShuffledIndex(rnd, kv.Len(), 
1, func(i int) { + fn(i, kv.entries[i].key, kv.entries[i].value) + }) +} + +func (kv KeyValue) IterateShuffledString(rnd *rand.Rand, fn func(i int, key, value string)) { + kv.IterateShuffled(rnd, func(i int, key, value []byte) { + fn(i, string(key), string(value)) + }) +} + +func (kv KeyValue) IterateInexact(fn func(i int, key_, key, value []byte)) { + for i := range kv.entries { + key_, key, value := kv.IndexInexact(i) + fn(i, key_, key, value) + } +} + +func (kv KeyValue) IterateInexactString(fn func(i int, key_, key, value string)) { + kv.IterateInexact(func(i int, key_, key, value []byte) { + fn(i, string(key_), string(key), string(value)) + }) +} + +func (kv KeyValue) Clone() KeyValue { + return KeyValue{append([]KeyValueEntry{}, kv.entries...), kv.nbytes} +} + +func (kv KeyValue) Slice(start, limit int) KeyValue { + if start < 0 || limit > kv.Len() { + panic(fmt.Sprintf("Slice %d .. %d: out of range", start, limit)) + } else if limit < start { + panic(fmt.Sprintf("Slice %d .. %d: invalid range", start, limit)) + } + return KeyValue{append([]KeyValueEntry{}, kv.entries[start:limit]...), kv.nbytes} +} + +func (kv KeyValue) SliceKey(start, limit []byte) KeyValue { + start_ := 0 + limit_ := kv.Len() + if start != nil { + start_ = kv.Search(start) + } + if limit != nil { + limit_ = kv.Search(limit) + } + return kv.Slice(start_, limit_) +} + +func (kv KeyValue) SliceKeyString(start, limit string) KeyValue { + return kv.SliceKey([]byte(start), []byte(limit)) +} + +func (kv KeyValue) SliceRange(r *util.Range) KeyValue { + if r != nil { + return kv.SliceKey(r.Start, r.Limit) + } + return kv.Clone() +} + +func (kv KeyValue) Range(start, limit int) (r util.Range) { + if kv.Len() > 0 { + if start == kv.Len() { + r.Start = BytesAfter(kv.KeyAt(start - 1)) + } else { + r.Start = kv.KeyAt(start) + } + } + if limit < kv.Len() { + r.Limit = kv.KeyAt(limit) + } + return +} + +func KeyValue_EmptyKey() *KeyValue { + kv := &KeyValue{} + kv.PutString("", "v") + return kv +} + +func KeyValue_EmptyValue() *KeyValue { + kv := &KeyValue{} + kv.PutString("abc", "") + kv.PutString("abcd", "") + return kv +} + +func KeyValue_OneKeyValue() *KeyValue { + kv := &KeyValue{} + kv.PutString("abc", "v") + return kv +} + +func KeyValue_BigValue() *KeyValue { + kv := &KeyValue{} + kv.PutString("big1", strings.Repeat("1", 200000)) + return kv +} + +func KeyValue_SpecialKey() *KeyValue { + kv := &KeyValue{} + kv.PutString("\xff\xff", "v3") + return kv +} + +func KeyValue_MultipleKeyValue() *KeyValue { + kv := &KeyValue{} + kv.PutString("a", "v") + kv.PutString("aa", "v1") + kv.PutString("aaa", "v2") + kv.PutString("aaacccccccccc", "v2") + kv.PutString("aaaccccccccccd", "v3") + kv.PutString("aaaccccccccccf", "v4") + kv.PutString("aaaccccccccccfg", "v5") + kv.PutString("ab", "v6") + kv.PutString("abc", "v7") + kv.PutString("abcd", "v8") + kv.PutString("accccccccccccccc", "v9") + kv.PutString("b", "v10") + kv.PutString("bb", "v11") + kv.PutString("bc", "v12") + kv.PutString("c", "v13") + kv.PutString("c1", "v13") + kv.PutString("czzzzzzzzzzzzzz", "v14") + kv.PutString("fffffffffffffff", "v15") + kv.PutString("g11", "v15") + kv.PutString("g111", "v15") + kv.PutString("g111\xff", "v15") + kv.PutString("zz", "v16") + kv.PutString("zzzzzzz", "v16") + kv.PutString("zzzzzzzzzzzzzzzz", "v16") + return kv +} + +var keymap = []byte("012345678ABCDEFGHIJKLMNOPQRSTUVWXYabcdefghijklmnopqrstuvwxy") + +func KeyValue_Generate(rnd *rand.Rand, n, minlen, maxlen, vminlen, vmaxlen int) *KeyValue { + if rnd == nil { + rnd = NewRand() + } + if 
maxlen < minlen { + panic("max len should >= min len") + } + + rrand := func(min, max int) int { + if min == max { + return max + } + return rnd.Intn(max-min) + min + } + + kv := &KeyValue{} + endC := byte(len(keymap) - 1) + gen := make([]byte, 0, maxlen) + for i := 0; i < n; i++ { + m := rrand(minlen, maxlen) + last := gen + retry: + gen = last[:m] + if k := len(last); m > k { + for j := k; j < m; j++ { + gen[j] = 0 + } + } else { + for j := m - 1; j >= 0; j-- { + c := last[j] + if c == endC { + continue + } + gen[j] = c + 1 + for j += 1; j < m; j++ { + gen[j] = 0 + } + goto ok + } + if m < maxlen { + m++ + goto retry + } + panic(fmt.Sprintf("only able to generate %d keys out of %d keys, try increasing max len", kv.Len(), n)) + ok: + } + key := make([]byte, m) + for j := 0; j < m; j++ { + key[j] = keymap[gen[j]] + } + value := make([]byte, rrand(vminlen, vmaxlen)) + for n := copy(value, []byte(fmt.Sprintf("v%d", i))); n < len(value); n++ { + value[n] = 'x' + } + kv.Put(key, value) + } + return kv +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil/kvtest.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil/kvtest.go new file mode 100644 index 0000000000..a0b58f0e72 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil/kvtest.go @@ -0,0 +1,187 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package testutil + +import ( + "fmt" + "math/rand" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "github.com/syndtr/goleveldb/leveldb/errors" + "github.com/syndtr/goleveldb/leveldb/util" +) + +func KeyValueTesting(rnd *rand.Rand, kv KeyValue, p DB, setup func(KeyValue) DB, teardown func(DB)) { + if rnd == nil { + rnd = NewRand() + } + + if p == nil { + BeforeEach(func() { + p = setup(kv) + }) + if teardown != nil { + AfterEach(func() { + teardown(p) + }) + } + } + + It("Should find all keys with Find", func() { + if db, ok := p.(Find); ok { + ShuffledIndex(nil, kv.Len(), 1, func(i int) { + key_, key, value := kv.IndexInexact(i) + + // Using exact key. + rkey, rvalue, err := db.TestFind(key) + Expect(err).ShouldNot(HaveOccurred(), "Error for key %q", key) + Expect(rkey).Should(Equal(key), "Key") + Expect(rvalue).Should(Equal(value), "Value for key %q", key) + + // Using inexact key. + rkey, rvalue, err = db.TestFind(key_) + Expect(err).ShouldNot(HaveOccurred(), "Error for key %q (%q)", key_, key) + Expect(rkey).Should(Equal(key)) + Expect(rvalue).Should(Equal(value), "Value for key %q (%q)", key_, key) + }) + } + }) + + It("Should return error if the key is not present", func() { + if db, ok := p.(Find); ok { + var key []byte + if kv.Len() > 0 { + key_, _ := kv.Index(kv.Len() - 1) + key = BytesAfter(key_) + } + rkey, _, err := db.TestFind(key) + Expect(err).Should(HaveOccurred(), "Find for key %q yield key %q", key, rkey) + Expect(err).Should(Equal(errors.ErrNotFound)) + } + }) + + It("Should only find exact key with Get", func() { + if db, ok := p.(Get); ok { + ShuffledIndex(nil, kv.Len(), 1, func(i int) { + key_, key, value := kv.IndexInexact(i) + + // Using exact key. + rvalue, err := db.TestGet(key) + Expect(err).ShouldNot(HaveOccurred(), "Error for key %q", key) + Expect(rvalue).Should(Equal(value), "Value for key %q", key) + + // Using inexact key. 
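+				// key_ is a separator that sorts between the previous key
+				// and this key (see IndexInexact), so an exact-match Get
+				// on it must report ErrNotFound.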
+				if len(key_) > 0 {
+					_, err = db.TestGet(key_)
+					Expect(err).Should(HaveOccurred(), "Error for key %q", key_)
+					Expect(err).Should(Equal(errors.ErrNotFound))
+				}
+			})
+		}
+	})
+
+	It("Should only find present key with Has", func() {
+		if db, ok := p.(Has); ok {
+			ShuffledIndex(nil, kv.Len(), 1, func(i int) {
+				key_, key, _ := kv.IndexInexact(i)
+
+				// Using exact key.
+				ret, err := db.TestHas(key)
+				Expect(err).ShouldNot(HaveOccurred(), "Error for key %q", key)
+				Expect(ret).Should(BeTrue(), "False for key %q", key)
+
+				// Using inexact key.
+				if len(key_) > 0 {
+					ret, err = db.TestHas(key_)
+					Expect(err).ShouldNot(HaveOccurred(), "Error for key %q", key_)
+					Expect(ret).ShouldNot(BeTrue(), "True for key %q", key)
+				}
+			})
+		}
+	})
+
+	TestIter := func(r *util.Range, _kv KeyValue) {
+		if db, ok := p.(NewIterator); ok {
+			iter := db.TestNewIterator(r)
+			Expect(iter.Error()).ShouldNot(HaveOccurred())
+
+			t := IteratorTesting{
+				KeyValue: _kv,
+				Iter:     iter,
+			}
+
+			DoIteratorTesting(&t)
+			iter.Release()
+		}
+	}
+
+	It("Should iterates and seeks correctly", func(done Done) {
+		TestIter(nil, kv.Clone())
+		done <- true
+	}, 3.0)
+
+	RandomIndex(rnd, kv.Len(), Min(kv.Len(), 50), func(i int) {
+		type slice struct {
+			r            *util.Range
+			start, limit int
+		}
+
+		key_, _, _ := kv.IndexInexact(i)
+		for _, x := range []slice{
+			{&util.Range{Start: key_, Limit: nil}, i, kv.Len()},
+			{&util.Range{Start: nil, Limit: key_}, 0, i},
+		} {
+			It(fmt.Sprintf("Should iterates and seeks correctly of a slice %d .. %d", x.start, x.limit), func(done Done) {
+				TestIter(x.r, kv.Slice(x.start, x.limit))
+				done <- true
+			}, 3.0)
+		}
+	})
+
+	RandomRange(rnd, kv.Len(), Min(kv.Len(), 50), func(start, limit int) {
+		It(fmt.Sprintf("Should iterates and seeks correctly of a slice %d .. %d", start, limit), func(done Done) {
+			r := kv.Range(start, limit)
+			TestIter(&r, kv.Slice(start, limit))
+			done <- true
+		}, 3.0)
+	})
+}
+
+func AllKeyValueTesting(rnd *rand.Rand, body, setup func(KeyValue) DB, teardown func(DB)) {
+	Test := func(kv *KeyValue) func() {
+		return func() {
+			var p DB
+			if setup != nil {
+				Defer("setup", func() {
+					p = setup(*kv)
+				})
+			}
+			if teardown != nil {
+				Defer("teardown", func() {
+					teardown(p)
+				})
+			}
+			if body != nil {
+				p = body(*kv)
+			}
+			KeyValueTesting(rnd, *kv, p, func(KeyValue) DB {
+				return p
+			}, nil)
+		}
+	}
+
+	Describe("with no key/value (empty)", Test(&KeyValue{}))
+	Describe("with empty key", Test(KeyValue_EmptyKey()))
+	Describe("with empty value", Test(KeyValue_EmptyValue()))
+	Describe("with one key/value", Test(KeyValue_OneKeyValue()))
+	Describe("with big value", Test(KeyValue_BigValue()))
+	Describe("with special key", Test(KeyValue_SpecialKey()))
+	Describe("with multiple key/value", Test(KeyValue_MultipleKeyValue()))
+	Describe("with generated key/value", Test(KeyValue_Generate(nil, 120, 1, 50, 10, 120)))
+}
diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil/storage.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil/storage.go
new file mode 100644
index 0000000000..59c496d54c
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil/storage.go
@@ -0,0 +1,586 @@
+// Copyright (c) 2014, Suryandaru Triandana <syndtr@gmail.com>
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+package testutil
+
+import (
+	"bytes"
+	"fmt"
+	"io"
+	"os"
+	"path/filepath"
+	"runtime"
+	"strings"
+	"sync"
+
+	. "github.com/onsi/gomega"
+
+	"github.com/syndtr/goleveldb/leveldb/storage"
+	"github.com/syndtr/goleveldb/leveldb/util"
+)
+
+var (
+	storageMu     sync.Mutex
+	storageUseFS  bool = true
+	storageKeepFS bool = false
+	storageNum    int
+)
+
+type StorageMode int
+
+const (
+	ModeOpen StorageMode = 1 << iota
+	ModeCreate
+	ModeRemove
+	ModeRead
+	ModeWrite
+	ModeSync
+	ModeClose
+)
+
+const (
+	modeOpen = iota
+	modeCreate
+	modeRemove
+	modeRead
+	modeWrite
+	modeSync
+	modeClose
+
+	modeCount
+)
+
+const (
+	typeManifest = iota
+	typeJournal
+	typeTable
+	typeTemp
+
+	typeCount
+)
+
+const flattenCount = modeCount * typeCount
+
+func flattenType(m StorageMode, t storage.FileType) int {
+	var x int
+	switch m {
+	case ModeOpen:
+		x = modeOpen
+	case ModeCreate:
+		x = modeCreate
+	case ModeRemove:
+		x = modeRemove
+	case ModeRead:
+		x = modeRead
+	case ModeWrite:
+		x = modeWrite
+	case ModeSync:
+		x = modeSync
+	case ModeClose:
+		x = modeClose
+	default:
+		panic("invalid storage mode")
+	}
+	x *= typeCount
+	switch t {
+	case storage.TypeManifest:
+		return x + typeManifest
+	case storage.TypeJournal:
+		return x + typeJournal
+	case storage.TypeTable:
+		return x + typeTable
+	case storage.TypeTemp:
+		return x + typeTemp
+	default:
+		panic("invalid file type")
+	}
+}
+
+func listFlattenType(m StorageMode, t storage.FileType) []int {
+	ret := make([]int, 0, flattenCount)
+	add := func(x int) {
+		x *= typeCount
+		switch {
+		case t&storage.TypeManifest != 0:
+			ret = append(ret, x+typeManifest)
+		case t&storage.TypeJournal != 0:
+			ret = append(ret, x+typeJournal)
+		case t&storage.TypeTable != 0:
+			ret = append(ret, x+typeTable)
+		case t&storage.TypeTemp != 0:
+			ret = append(ret, x+typeTemp)
+		}
+	}
+	switch {
+	case m&ModeOpen != 0:
+		add(modeOpen)
+	case m&ModeCreate != 0:
+		add(modeCreate)
+	case m&ModeRemove != 0:
+		add(modeRemove)
+	case m&ModeRead != 0:
+		add(modeRead)
+	case m&ModeWrite != 0:
+		add(modeWrite)
+	case m&ModeSync != 0:
+		add(modeSync)
+	case m&ModeClose != 0:
+		add(modeClose)
+	}
+	return ret
+}
+
+func packFile(num uint64, t storage.FileType) uint64 {
+	if num>>(64-typeCount) != 0 {
+		panic("overflow")
+	}
+	return num<<typeCount | uint64(t)
+}
+
+func unpackFile(x uint64) (uint64, storage.FileType) {
+	return x >> typeCount, storage.FileType(x) & storage.TypeAll
+}
+
+type emulatedError struct {
+	err error
+}
+
+func (err emulatedError) Error() string {
+	return fmt.Sprintf("emulated storage error: %v", err.err)
+}
+
+type storageLock struct {
+	s *Storage
+	r util.Releaser
+}
+
+func (l storageLock) Release() {
+	l.r.Release()
+	l.s.logI("storage lock released")
+}
+
+type reader struct {
+	f *file
+	storage.Reader
+}
+
+func (r *reader) Read(p []byte) (n int, err error) {
+	err = r.f.s.emulateError(ModeRead, r.f.Type())
+	if err == nil {
+		r.f.s.stall(ModeRead, r.f.Type())
+		n, err = r.Reader.Read(p)
+	}
+	r.f.s.count(ModeRead, r.f.Type(), n)
+	if err != nil && err != io.EOF {
+		r.f.s.logI("read error, num=%d type=%v n=%d err=%v", r.f.Num(), r.f.Type(), n, err)
+	}
+	return
+}
+
+func (r *reader) ReadAt(p []byte, off int64) (n int, err error) {
+	err = r.f.s.emulateError(ModeRead, r.f.Type())
+	if err == nil {
+		r.f.s.stall(ModeRead, r.f.Type())
+		n, err = r.Reader.ReadAt(p, off)
+	}
+	r.f.s.count(ModeRead, r.f.Type(), n)
+	if err != nil && err != io.EOF {
+		r.f.s.logI("readAt error, num=%d type=%v offset=%d n=%d err=%v", r.f.Num(), r.f.Type(), off, n, err)
+	}
+	return
+}
+
+func (r *reader) Close() (err error) {
+	return r.f.doClose(r.Reader)
+}
+
+type writer struct {
+	f *file
+	storage.Writer
+}
+
+func (w *writer) Write(p []byte) (n int, err error) {
+	err = 
w.f.s.emulateError(ModeWrite, w.f.Type()) + if err == nil { + w.f.s.stall(ModeWrite, w.f.Type()) + n, err = w.Writer.Write(p) + } + w.f.s.count(ModeWrite, w.f.Type(), n) + if err != nil && err != io.EOF { + w.f.s.logI("write error, num=%d type=%v n=%d err=%v", w.f.Num(), w.f.Type(), n, err) + } + return +} + +func (w *writer) Sync() (err error) { + err = w.f.s.emulateError(ModeSync, w.f.Type()) + if err == nil { + w.f.s.stall(ModeSync, w.f.Type()) + err = w.Writer.Sync() + } + w.f.s.count(ModeSync, w.f.Type(), 0) + if err != nil { + w.f.s.logI("sync error, num=%d type=%v err=%v", w.f.Num(), w.f.Type(), err) + } + return +} + +func (w *writer) Close() (err error) { + return w.f.doClose(w.Writer) +} + +type file struct { + s *Storage + storage.File +} + +func (f *file) pack() uint64 { + return packFile(f.Num(), f.Type()) +} + +func (f *file) assertOpen() { + ExpectWithOffset(2, f.s.opens).NotTo(HaveKey(f.pack()), "File open, num=%d type=%v writer=%v", f.Num(), f.Type(), f.s.opens[f.pack()]) +} + +func (f *file) doClose(closer io.Closer) (err error) { + err = f.s.emulateError(ModeClose, f.Type()) + if err == nil { + f.s.stall(ModeClose, f.Type()) + } + f.s.mu.Lock() + defer f.s.mu.Unlock() + if err == nil { + ExpectWithOffset(2, f.s.opens).To(HaveKey(f.pack()), "File closed, num=%d type=%v", f.Num(), f.Type()) + err = closer.Close() + } + f.s.countNB(ModeClose, f.Type(), 0) + writer := f.s.opens[f.pack()] + if err != nil { + f.s.logISkip(1, "file close failed, num=%d type=%v writer=%v err=%v", f.Num(), f.Type(), writer, err) + } else { + f.s.logISkip(1, "file closed, num=%d type=%v writer=%v", f.Num(), f.Type(), writer) + delete(f.s.opens, f.pack()) + } + return +} + +func (f *file) Open() (r storage.Reader, err error) { + err = f.s.emulateError(ModeOpen, f.Type()) + if err == nil { + f.s.stall(ModeOpen, f.Type()) + } + f.s.mu.Lock() + defer f.s.mu.Unlock() + if err == nil { + f.assertOpen() + f.s.countNB(ModeOpen, f.Type(), 0) + r, err = f.File.Open() + } + if err != nil { + f.s.logI("file open failed, num=%d type=%v err=%v", f.Num(), f.Type(), err) + } else { + f.s.logI("file opened, num=%d type=%v", f.Num(), f.Type()) + f.s.opens[f.pack()] = false + r = &reader{f, r} + } + return +} + +func (f *file) Create() (w storage.Writer, err error) { + err = f.s.emulateError(ModeCreate, f.Type()) + if err == nil { + f.s.stall(ModeCreate, f.Type()) + } + f.s.mu.Lock() + defer f.s.mu.Unlock() + if err == nil { + f.assertOpen() + f.s.countNB(ModeCreate, f.Type(), 0) + w, err = f.File.Create() + } + if err != nil { + f.s.logI("file create failed, num=%d type=%v err=%v", f.Num(), f.Type(), err) + } else { + f.s.logI("file created, num=%d type=%v", f.Num(), f.Type()) + f.s.opens[f.pack()] = true + w = &writer{f, w} + } + return +} + +func (f *file) Remove() (err error) { + err = f.s.emulateError(ModeRemove, f.Type()) + if err == nil { + f.s.stall(ModeRemove, f.Type()) + } + f.s.mu.Lock() + defer f.s.mu.Unlock() + if err == nil { + f.assertOpen() + f.s.countNB(ModeRemove, f.Type(), 0) + err = f.File.Remove() + } + if err != nil { + f.s.logI("file remove failed, num=%d type=%v err=%v", f.Num(), f.Type(), err) + } else { + f.s.logI("file removed, num=%d type=%v", f.Num(), f.Type()) + } + return +} + +type Storage struct { + storage.Storage + closeFn func() error + + lmu sync.Mutex + lb bytes.Buffer + + mu sync.Mutex + // Open files, true=writer, false=reader + opens map[uint64]bool + counters [flattenCount]int + bytesCounter [flattenCount]int64 + emulatedError [flattenCount]error + stallCond sync.Cond + 
stalled [flattenCount]bool +} + +func (s *Storage) log(skip int, str string) { + s.lmu.Lock() + defer s.lmu.Unlock() + _, file, line, ok := runtime.Caller(skip + 2) + if ok { + // Truncate file name at last file name separator. + if index := strings.LastIndex(file, "/"); index >= 0 { + file = file[index+1:] + } else if index = strings.LastIndex(file, "\\"); index >= 0 { + file = file[index+1:] + } + } else { + file = "???" + line = 1 + } + fmt.Fprintf(&s.lb, "%s:%d: ", file, line) + lines := strings.Split(str, "\n") + if l := len(lines); l > 1 && lines[l-1] == "" { + lines = lines[:l-1] + } + for i, line := range lines { + if i > 0 { + s.lb.WriteString("\n\t") + } + s.lb.WriteString(line) + } + s.lb.WriteByte('\n') +} + +func (s *Storage) logISkip(skip int, format string, args ...interface{}) { + pc, _, _, ok := runtime.Caller(skip + 1) + if ok { + if f := runtime.FuncForPC(pc); f != nil { + fname := f.Name() + if index := strings.LastIndex(fname, "."); index >= 0 { + fname = fname[index+1:] + } + format = fname + ": " + format + } + } + s.log(skip+1, fmt.Sprintf(format, args...)) +} + +func (s *Storage) logI(format string, args ...interface{}) { + s.logISkip(1, format, args...) +} + +func (s *Storage) Log(str string) { + s.log(1, "Log: "+str) + s.Storage.Log(str) +} + +func (s *Storage) Lock() (r util.Releaser, err error) { + r, err = s.Storage.Lock() + if err != nil { + s.logI("storage locking failed, err=%v", err) + } else { + s.logI("storage locked") + r = storageLock{s, r} + } + return +} + +func (s *Storage) GetFile(num uint64, t storage.FileType) storage.File { + return &file{s, s.Storage.GetFile(num, t)} +} + +func (s *Storage) GetFiles(t storage.FileType) (files []storage.File, err error) { + rfiles, err := s.Storage.GetFiles(t) + if err != nil { + s.logI("get files failed, err=%v", err) + return + } + files = make([]storage.File, len(rfiles)) + for i, f := range rfiles { + files[i] = &file{s, f} + } + s.logI("get files, type=0x%x count=%d", int(t), len(files)) + return +} + +func (s *Storage) GetManifest() (f storage.File, err error) { + manifest, err := s.Storage.GetManifest() + if err != nil { + if !os.IsNotExist(err) { + s.logI("get manifest failed, err=%v", err) + } + return + } + s.logI("get manifest, num=%d", manifest.Num()) + return &file{s, manifest}, nil +} + +func (s *Storage) SetManifest(f storage.File) error { + f_, ok := f.(*file) + ExpectWithOffset(1, ok).To(BeTrue()) + ExpectWithOffset(1, f_.Type()).To(Equal(storage.TypeManifest)) + err := s.Storage.SetManifest(f_.File) + if err != nil { + s.logI("set manifest failed, err=%v", err) + } else { + s.logI("set manifest, num=%d", f_.Num()) + } + return err +} + +func (s *Storage) openFiles() string { + out := "Open files:" + for x, writer := range s.opens { + num, t := unpackFile(x) + out += fmt.Sprintf("\n · num=%d type=%v writer=%v", num, t, writer) + } + return out +} + +func (s *Storage) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + ExpectWithOffset(1, s.opens).To(BeEmpty(), s.openFiles()) + err := s.Storage.Close() + if err != nil { + s.logI("storage closing failed, err=%v", err) + } else { + s.logI("storage closed") + } + if s.closeFn != nil { + if err1 := s.closeFn(); err1 != nil { + s.logI("close func error, err=%v", err1) + } + } + return err +} + +func (s *Storage) countNB(m StorageMode, t storage.FileType, n int) { + s.counters[flattenType(m, t)]++ + s.bytesCounter[flattenType(m, t)] += int64(n) +} + +func (s *Storage) count(m StorageMode, t storage.FileType, n int) { + s.mu.Lock() + defer 
s.mu.Unlock() + s.countNB(m, t, n) +} + +func (s *Storage) ResetCounter(m StorageMode, t storage.FileType) { + for _, x := range listFlattenType(m, t) { + s.counters[x] = 0 + s.bytesCounter[x] = 0 + } +} + +func (s *Storage) Counter(m StorageMode, t storage.FileType) (count int, bytes int64) { + for _, x := range listFlattenType(m, t) { + count += s.counters[x] + bytes += s.bytesCounter[x] + } + return +} + +func (s *Storage) emulateError(m StorageMode, t storage.FileType) error { + s.mu.Lock() + defer s.mu.Unlock() + err := s.emulatedError[flattenType(m, t)] + if err != nil { + return emulatedError{err} + } + return nil +} + +func (s *Storage) EmulateError(m StorageMode, t storage.FileType, err error) { + s.mu.Lock() + defer s.mu.Unlock() + for _, x := range listFlattenType(m, t) { + s.emulatedError[x] = err + } +} + +func (s *Storage) stall(m StorageMode, t storage.FileType) { + x := flattenType(m, t) + s.mu.Lock() + defer s.mu.Unlock() + for s.stalled[x] { + s.stallCond.Wait() + } +} + +func (s *Storage) Stall(m StorageMode, t storage.FileType) { + s.mu.Lock() + defer s.mu.Unlock() + for _, x := range listFlattenType(m, t) { + s.stalled[x] = true + } +} + +func (s *Storage) Release(m StorageMode, t storage.FileType) { + s.mu.Lock() + defer s.mu.Unlock() + for _, x := range listFlattenType(m, t) { + s.stalled[x] = false + } + s.stallCond.Broadcast() +} + +func NewStorage() *Storage { + var stor storage.Storage + var closeFn func() error + if storageUseFS { + for { + storageMu.Lock() + num := storageNum + storageNum++ + storageMu.Unlock() + path := filepath.Join(os.TempDir(), fmt.Sprintf("goleveldb-test%d0%d0%d", os.Getuid(), os.Getpid(), num)) + if _, err := os.Stat(path); os.IsNotExist(err) { + stor, err = storage.OpenFile(path) + ExpectWithOffset(1, err).NotTo(HaveOccurred(), "creating storage at %s", path) + closeFn = func() error { + if storageKeepFS { + return nil + } + return os.RemoveAll(path) + } + break + } + } + } else { + stor = storage.NewMemStorage() + } + s := &Storage{ + Storage: stor, + closeFn: closeFn, + opens: make(map[uint64]bool), + } + s.stallCond.L = &s.mu + return s +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil/util.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil/util.go new file mode 100644 index 0000000000..97c5294b1b --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil/util.go @@ -0,0 +1,171 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package testutil + +import ( + "bytes" + "flag" + "math/rand" + "reflect" + "sync" + + "github.com/onsi/ginkgo/config" + + "github.com/syndtr/goleveldb/leveldb/comparer" +) + +var ( + runfn = make(map[string][]func()) + runmu sync.Mutex +) + +func Defer(args ...interface{}) bool { + var ( + group string + fn func() + ) + for _, arg := range args { + v := reflect.ValueOf(arg) + switch v.Kind() { + case reflect.String: + group = v.String() + case reflect.Func: + r := reflect.ValueOf(&fn).Elem() + r.Set(v) + } + } + if fn != nil { + runmu.Lock() + runfn[group] = append(runfn[group], fn) + runmu.Unlock() + } + return true +} + +func RunDefer(groups ...string) bool { + if len(groups) == 0 { + groups = append(groups, "") + } + runmu.Lock() + var runfn_ []func() + for _, group := range groups { + runfn_ = append(runfn_, runfn[group]...) 
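+		// Drop the group once collected so each deferred function runs
+		// at most once, even if RunDefer is called again.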
+ delete(runfn, group) + } + runmu.Unlock() + for _, fn := range runfn_ { + fn() + } + return runfn_ != nil +} + +func RandomSeed() int64 { + if !flag.Parsed() { + panic("random seed not initialized") + } + return config.GinkgoConfig.RandomSeed +} + +func NewRand() *rand.Rand { + return rand.New(rand.NewSource(RandomSeed())) +} + +var cmp = comparer.DefaultComparer + +func BytesSeparator(a, b []byte) []byte { + if bytes.Equal(a, b) { + return b + } + i, n := 0, len(a) + if n > len(b) { + n = len(b) + } + for ; i < n && (a[i] == b[i]); i++ { + } + x := append([]byte{}, a[:i]...) + if i < n { + if c := a[i] + 1; c < b[i] { + return append(x, c) + } + x = append(x, a[i]) + i++ + } + for ; i < len(a); i++ { + if c := a[i]; c < 0xff { + return append(x, c+1) + } else { + x = append(x, c) + } + } + if len(b) > i && b[i] > 0 { + return append(x, b[i]-1) + } + return append(x, 'x') +} + +func BytesAfter(b []byte) []byte { + var x []byte + for _, c := range b { + if c < 0xff { + return append(x, c+1) + } else { + x = append(x, c) + } + } + return append(x, 'x') +} + +func RandomIndex(rnd *rand.Rand, n, round int, fn func(i int)) { + if rnd == nil { + rnd = NewRand() + } + for x := 0; x < round; x++ { + fn(rnd.Intn(n)) + } + return +} + +func ShuffledIndex(rnd *rand.Rand, n, round int, fn func(i int)) { + if rnd == nil { + rnd = NewRand() + } + for x := 0; x < round; x++ { + for _, i := range rnd.Perm(n) { + fn(i) + } + } + return +} + +func RandomRange(rnd *rand.Rand, n, round int, fn func(start, limit int)) { + if rnd == nil { + rnd = NewRand() + } + for x := 0; x < round; x++ { + start := rnd.Intn(n) + length := 0 + if j := n - start; j > 0 { + length = rnd.Intn(j) + } + fn(start, start+length) + } + return +} + +func Max(x, y int) int { + if x > y { + return x + } + return y +} + +func Min(x, y int) int { + if x < y { + return x + } + return y +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil_test.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil_test.go new file mode 100644 index 0000000000..25bf2b29f9 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/testutil_test.go @@ -0,0 +1,63 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package leveldb + +import ( + . 
"github.com/onsi/gomega" + + "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/opt" + "github.com/syndtr/goleveldb/leveldb/testutil" + "github.com/syndtr/goleveldb/leveldb/util" +) + +type testingDB struct { + *DB + ro *opt.ReadOptions + wo *opt.WriteOptions + stor *testutil.Storage +} + +func (t *testingDB) TestPut(key []byte, value []byte) error { + return t.Put(key, value, t.wo) +} + +func (t *testingDB) TestDelete(key []byte) error { + return t.Delete(key, t.wo) +} + +func (t *testingDB) TestGet(key []byte) (value []byte, err error) { + return t.Get(key, t.ro) +} + +func (t *testingDB) TestHas(key []byte) (ret bool, err error) { + return t.Has(key, t.ro) +} + +func (t *testingDB) TestNewIterator(slice *util.Range) iterator.Iterator { + return t.NewIterator(slice, t.ro) +} + +func (t *testingDB) TestClose() { + err := t.Close() + ExpectWithOffset(1, err).NotTo(HaveOccurred()) + err = t.stor.Close() + ExpectWithOffset(1, err).NotTo(HaveOccurred()) +} + +func newTestingDB(o *opt.Options, ro *opt.ReadOptions, wo *opt.WriteOptions) *testingDB { + stor := testutil.NewStorage() + db, err := Open(stor, o) + // FIXME: This may be called from outside It, which may cause panic. + Expect(err).NotTo(HaveOccurred()) + return &testingDB{ + DB: db, + ro: ro, + wo: wo, + stor: stor, + } +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util.go new file mode 100644 index 0000000000..1a5bf71a32 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util.go @@ -0,0 +1,91 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package leveldb + +import ( + "fmt" + "sort" + + "github.com/syndtr/goleveldb/leveldb/storage" +) + +func shorten(str string) string { + if len(str) <= 8 { + return str + } + return str[:3] + ".." + str[len(str)-3:] +} + +var bunits = [...]string{"", "Ki", "Mi", "Gi"} + +func shortenb(bytes int) string { + i := 0 + for ; bytes > 1024 && i < 4; i++ { + bytes /= 1024 + } + return fmt.Sprintf("%d%sB", bytes, bunits[i]) +} + +func sshortenb(bytes int) string { + if bytes == 0 { + return "~" + } + sign := "+" + if bytes < 0 { + sign = "-" + bytes *= -1 + } + i := 0 + for ; bytes > 1024 && i < 4; i++ { + bytes /= 1024 + } + return fmt.Sprintf("%s%d%sB", sign, bytes, bunits[i]) +} + +func sint(x int) string { + if x == 0 { + return "~" + } + sign := "+" + if x < 0 { + sign = "-" + x *= -1 + } + return fmt.Sprintf("%s%d", sign, x) +} + +func minInt(a, b int) int { + if a < b { + return a + } + return b +} + +func maxInt(a, b int) int { + if a > b { + return a + } + return b +} + +type files []storage.File + +func (p files) Len() int { + return len(p) +} + +func (p files) Less(i, j int) bool { + return p[i].Num() < p[j].Num() +} + +func (p files) Swap(i, j int) { + p[i], p[j] = p[j], p[i] +} + +func (p files) sort() { + sort.Sort(p) +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/buffer.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/buffer.go new file mode 100644 index 0000000000..21de242552 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/buffer.go @@ -0,0 +1,293 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package util + +// This a copy of Go std bytes.Buffer with some modification +// and some features stripped. + +import ( + "bytes" + "io" +) + +// A Buffer is a variable-sized buffer of bytes with Read and Write methods. +// The zero value for Buffer is an empty buffer ready to use. +type Buffer struct { + buf []byte // contents are the bytes buf[off : len(buf)] + off int // read at &buf[off], write at &buf[len(buf)] + bootstrap [64]byte // memory to hold first slice; helps small buffers (Printf) avoid allocation. +} + +// Bytes returns a slice of the contents of the unread portion of the buffer; +// len(b.Bytes()) == b.Len(). If the caller changes the contents of the +// returned slice, the contents of the buffer will change provided there +// are no intervening method calls on the Buffer. +func (b *Buffer) Bytes() []byte { return b.buf[b.off:] } + +// String returns the contents of the unread portion of the buffer +// as a string. If the Buffer is a nil pointer, it returns "". +func (b *Buffer) String() string { + if b == nil { + // Special case, useful in debugging. + return "" + } + return string(b.buf[b.off:]) +} + +// Len returns the number of bytes of the unread portion of the buffer; +// b.Len() == len(b.Bytes()). +func (b *Buffer) Len() int { return len(b.buf) - b.off } + +// Truncate discards all but the first n unread bytes from the buffer. +// It panics if n is negative or greater than the length of the buffer. +func (b *Buffer) Truncate(n int) { + switch { + case n < 0 || n > b.Len(): + panic("leveldb/util.Buffer: truncation out of range") + case n == 0: + // Reuse buffer space. + b.off = 0 + } + b.buf = b.buf[0 : b.off+n] +} + +// Reset resets the buffer so it has no content. +// b.Reset() is the same as b.Truncate(0). +func (b *Buffer) Reset() { b.Truncate(0) } + +// grow grows the buffer to guarantee space for n more bytes. +// It returns the index where bytes should be written. +// If the buffer can't grow it will panic with bytes.ErrTooLarge. +func (b *Buffer) grow(n int) int { + m := b.Len() + // If buffer is empty, reset to recover space. + if m == 0 && b.off != 0 { + b.Truncate(0) + } + if len(b.buf)+n > cap(b.buf) { + var buf []byte + if b.buf == nil && n <= len(b.bootstrap) { + buf = b.bootstrap[0:] + } else if m+n <= cap(b.buf)/2 { + // We can slide things down instead of allocating a new + // slice. We only need m+n <= cap(b.buf) to slide, but + // we instead let capacity get twice as large so we + // don't spend all our time copying. + copy(b.buf[:], b.buf[b.off:]) + buf = b.buf[:m] + } else { + // not enough space anywhere + buf = makeSlice(2*cap(b.buf) + n) + copy(buf, b.buf[b.off:]) + } + b.buf = buf + b.off = 0 + } + b.buf = b.buf[0 : b.off+m+n] + return b.off + m +} + +// Alloc allocs n bytes of slice from the buffer, growing the buffer as +// needed. If n is negative, Alloc will panic. +// If the buffer can't grow it will panic with bytes.ErrTooLarge. +func (b *Buffer) Alloc(n int) []byte { + if n < 0 { + panic("leveldb/util.Buffer.Alloc: negative count") + } + m := b.grow(n) + return b.buf[m:] +} + +// Grow grows the buffer's capacity, if necessary, to guarantee space for +// another n bytes. After Grow(n), at least n bytes can be written to the +// buffer without another allocation. +// If n is negative, Grow will panic. +// If the buffer can't grow it will panic with bytes.ErrTooLarge. 
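+// Note that Grow only reserves capacity: the buffer's length, and hence
+// the data it already holds, is left unchanged.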
+func (b *Buffer) Grow(n int) { + if n < 0 { + panic("leveldb/util.Buffer.Grow: negative count") + } + m := b.grow(n) + b.buf = b.buf[0:m] +} + +// Write appends the contents of p to the buffer, growing the buffer as +// needed. The return value n is the length of p; err is always nil. If the +// buffer becomes too large, Write will panic with bytes.ErrTooLarge. +func (b *Buffer) Write(p []byte) (n int, err error) { + m := b.grow(len(p)) + return copy(b.buf[m:], p), nil +} + +// MinRead is the minimum slice size passed to a Read call by +// Buffer.ReadFrom. As long as the Buffer has at least MinRead bytes beyond +// what is required to hold the contents of r, ReadFrom will not grow the +// underlying buffer. +const MinRead = 512 + +// ReadFrom reads data from r until EOF and appends it to the buffer, growing +// the buffer as needed. The return value n is the number of bytes read. Any +// error except io.EOF encountered during the read is also returned. If the +// buffer becomes too large, ReadFrom will panic with bytes.ErrTooLarge. +func (b *Buffer) ReadFrom(r io.Reader) (n int64, err error) { + // If buffer is empty, reset to recover space. + if b.off >= len(b.buf) { + b.Truncate(0) + } + for { + if free := cap(b.buf) - len(b.buf); free < MinRead { + // not enough space at end + newBuf := b.buf + if b.off+free < MinRead { + // not enough space using beginning of buffer; + // double buffer capacity + newBuf = makeSlice(2*cap(b.buf) + MinRead) + } + copy(newBuf, b.buf[b.off:]) + b.buf = newBuf[:len(b.buf)-b.off] + b.off = 0 + } + m, e := r.Read(b.buf[len(b.buf):cap(b.buf)]) + b.buf = b.buf[0 : len(b.buf)+m] + n += int64(m) + if e == io.EOF { + break + } + if e != nil { + return n, e + } + } + return n, nil // err is EOF, so return nil explicitly +} + +// makeSlice allocates a slice of size n. If the allocation fails, it panics +// with bytes.ErrTooLarge. +func makeSlice(n int) []byte { + // If the make fails, give a known error. + defer func() { + if recover() != nil { + panic(bytes.ErrTooLarge) + } + }() + return make([]byte, n) +} + +// WriteTo writes data to w until the buffer is drained or an error occurs. +// The return value n is the number of bytes written; it always fits into an +// int, but it is int64 to match the io.WriterTo interface. Any error +// encountered during the write is also returned. +func (b *Buffer) WriteTo(w io.Writer) (n int64, err error) { + if b.off < len(b.buf) { + nBytes := b.Len() + m, e := w.Write(b.buf[b.off:]) + if m > nBytes { + panic("leveldb/util.Buffer.WriteTo: invalid Write count") + } + b.off += m + n = int64(m) + if e != nil { + return n, e + } + // all bytes should have been written, by definition of + // Write method in io.Writer + if m != nBytes { + return n, io.ErrShortWrite + } + } + // Buffer is now empty; reset. + b.Truncate(0) + return +} + +// WriteByte appends the byte c to the buffer, growing the buffer as needed. +// The returned error is always nil, but is included to match bufio.Writer's +// WriteByte. If the buffer becomes too large, WriteByte will panic with +// bytes.ErrTooLarge. +func (b *Buffer) WriteByte(c byte) error { + m := b.grow(1) + b.buf[m] = c + return nil +} + +// Read reads the next len(p) bytes from the buffer or until the buffer +// is drained. The return value n is the number of bytes read. If the +// buffer has no data to return, err is io.EOF (unless len(p) is zero); +// otherwise it is nil. 
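+// When the buffer is empty, Read also resets it (off returns to 0) so
+// the underlying space can be reused by subsequent writes.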
+func (b *Buffer) Read(p []byte) (n int, err error) { + if b.off >= len(b.buf) { + // Buffer is empty, reset to recover space. + b.Truncate(0) + if len(p) == 0 { + return + } + return 0, io.EOF + } + n = copy(p, b.buf[b.off:]) + b.off += n + return +} + +// Next returns a slice containing the next n bytes from the buffer, +// advancing the buffer as if the bytes had been returned by Read. +// If there are fewer than n bytes in the buffer, Next returns the entire buffer. +// The slice is only valid until the next call to a read or write method. +func (b *Buffer) Next(n int) []byte { + m := b.Len() + if n > m { + n = m + } + data := b.buf[b.off : b.off+n] + b.off += n + return data +} + +// ReadByte reads and returns the next byte from the buffer. +// If no byte is available, it returns error io.EOF. +func (b *Buffer) ReadByte() (c byte, err error) { + if b.off >= len(b.buf) { + // Buffer is empty, reset to recover space. + b.Truncate(0) + return 0, io.EOF + } + c = b.buf[b.off] + b.off++ + return c, nil +} + +// ReadBytes reads until the first occurrence of delim in the input, +// returning a slice containing the data up to and including the delimiter. +// If ReadBytes encounters an error before finding a delimiter, +// it returns the data read before the error and the error itself (often io.EOF). +// ReadBytes returns err != nil if and only if the returned data does not end in +// delim. +func (b *Buffer) ReadBytes(delim byte) (line []byte, err error) { + slice, err := b.readSlice(delim) + // return a copy of slice. The buffer's backing array may + // be overwritten by later calls. + line = append(line, slice...) + return +} + +// readSlice is like ReadBytes but returns a reference to internal buffer data. +func (b *Buffer) readSlice(delim byte) (line []byte, err error) { + i := bytes.IndexByte(b.buf[b.off:], delim) + end := b.off + i + 1 + if i < 0 { + end = len(b.buf) + err = io.EOF + } + line = b.buf[b.off:end] + b.off = end + return line, err +} + +// NewBuffer creates and initializes a new Buffer using buf as its initial +// contents. It is intended to prepare a Buffer to read existing data. It +// can also be used to size the internal buffer for writing. To do that, +// buf should have the desired capacity but a length of zero. +// +// In most cases, new(Buffer) (or just declaring a Buffer variable) is +// sufficient to initialize a Buffer. +func NewBuffer(buf []byte) *Buffer { return &Buffer{buf: buf} } diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/buffer_pool.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/buffer_pool.go new file mode 100644 index 0000000000..2b8453d759 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/buffer_pool.go @@ -0,0 +1,238 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package util + +import ( + "fmt" + "sync" + "sync/atomic" + "time" +) + +type buffer struct { + b []byte + miss int +} + +// BufferPool is a 'buffer pool'. 
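+// Buffers are grouped into size classes derived from the baseline
+// passed to NewBufferPool; each class is backed by a small buffered
+// channel, so Get and Put never block. The get/put/half/less/equal/
+// greater/miss fields only collect statistics for String().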
+type BufferPool struct { + pool [6]chan []byte + size [5]uint32 + sizeMiss [5]uint32 + sizeHalf [5]uint32 + baseline [4]int + baseline0 int + + mu sync.RWMutex + closed bool + closeC chan struct{} + + get uint32 + put uint32 + half uint32 + less uint32 + equal uint32 + greater uint32 + miss uint32 +} + +func (p *BufferPool) poolNum(n int) int { + if n <= p.baseline0 && n > p.baseline0/2 { + return 0 + } + for i, x := range p.baseline { + if n <= x { + return i + 1 + } + } + return len(p.baseline) + 1 +} + +// Get returns buffer with length of n. +func (p *BufferPool) Get(n int) []byte { + if p == nil { + return make([]byte, n) + } + + p.mu.RLock() + defer p.mu.RUnlock() + + if p.closed { + return make([]byte, n) + } + + atomic.AddUint32(&p.get, 1) + + poolNum := p.poolNum(n) + pool := p.pool[poolNum] + if poolNum == 0 { + // Fast path. + select { + case b := <-pool: + switch { + case cap(b) > n: + if cap(b)-n >= n { + atomic.AddUint32(&p.half, 1) + select { + case pool <- b: + default: + } + return make([]byte, n) + } else { + atomic.AddUint32(&p.less, 1) + return b[:n] + } + case cap(b) == n: + atomic.AddUint32(&p.equal, 1) + return b[:n] + default: + atomic.AddUint32(&p.greater, 1) + } + default: + atomic.AddUint32(&p.miss, 1) + } + + return make([]byte, n, p.baseline0) + } else { + sizePtr := &p.size[poolNum-1] + + select { + case b := <-pool: + switch { + case cap(b) > n: + if cap(b)-n >= n { + atomic.AddUint32(&p.half, 1) + sizeHalfPtr := &p.sizeHalf[poolNum-1] + if atomic.AddUint32(sizeHalfPtr, 1) == 20 { + atomic.StoreUint32(sizePtr, uint32(cap(b)/2)) + atomic.StoreUint32(sizeHalfPtr, 0) + } else { + select { + case pool <- b: + default: + } + } + return make([]byte, n) + } else { + atomic.AddUint32(&p.less, 1) + return b[:n] + } + case cap(b) == n: + atomic.AddUint32(&p.equal, 1) + return b[:n] + default: + atomic.AddUint32(&p.greater, 1) + if uint32(cap(b)) >= atomic.LoadUint32(sizePtr) { + select { + case pool <- b: + default: + } + } + } + default: + atomic.AddUint32(&p.miss, 1) + } + + if size := atomic.LoadUint32(sizePtr); uint32(n) > size { + if size == 0 { + atomic.CompareAndSwapUint32(sizePtr, 0, uint32(n)) + } else { + sizeMissPtr := &p.sizeMiss[poolNum-1] + if atomic.AddUint32(sizeMissPtr, 1) == 20 { + atomic.StoreUint32(sizePtr, uint32(n)) + atomic.StoreUint32(sizeMissPtr, 0) + } + } + return make([]byte, n) + } else { + return make([]byte, n, size) + } + } +} + +// Put adds given buffer to the pool. +func (p *BufferPool) Put(b []byte) { + if p == nil { + return + } + + p.mu.RLock() + defer p.mu.RUnlock() + + if p.closed { + return + } + + atomic.AddUint32(&p.put, 1) + + pool := p.pool[p.poolNum(cap(b))] + select { + case pool <- b: + default: + } + +} + +func (p *BufferPool) Close() { + if p == nil { + return + } + + p.mu.Lock() + if !p.closed { + p.closed = true + p.closeC <- struct{}{} + } + p.mu.Unlock() +} + +func (p *BufferPool) String() string { + if p == nil { + return "" + } + + return fmt.Sprintf("BufferPool{B·%d Z·%v Zm·%v Zh·%v G·%d P·%d H·%d <·%d =·%d >·%d M·%d}", + p.baseline0, p.size, p.sizeMiss, p.sizeHalf, p.get, p.put, p.half, p.less, p.equal, p.greater, p.miss) +} + +func (p *BufferPool) drain() { + ticker := time.NewTicker(2 * time.Second) + for { + select { + case <-ticker.C: + for _, ch := range p.pool { + select { + case <-ch: + default: + } + } + case <-p.closeC: + close(p.closeC) + for _, ch := range p.pool { + close(ch) + } + return + } + } +} + +// NewBufferPool creates a new initialized 'buffer pool'. 
+func NewBufferPool(baseline int) *BufferPool { + if baseline <= 0 { + panic("baseline can't be <= 0") + } + p := &BufferPool{ + baseline0: baseline, + baseline: [...]int{baseline / 4, baseline / 2, baseline * 2, baseline * 4}, + closeC: make(chan struct{}, 1), + } + for i, cap := range []int{2, 2, 4, 4, 2, 1} { + p.pool[i] = make(chan []byte, cap) + } + go p.drain() + return p +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/buffer_test.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/buffer_test.go new file mode 100644 index 0000000000..87d96739c4 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/buffer_test.go @@ -0,0 +1,369 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package util + +import ( + "bytes" + "io" + "math/rand" + "runtime" + "testing" +) + +const N = 10000 // make this bigger for a larger (and slower) test +var data string // test data for write tests +var testBytes []byte // test data; same as data but as a slice. + +func init() { + testBytes = make([]byte, N) + for i := 0; i < N; i++ { + testBytes[i] = 'a' + byte(i%26) + } + data = string(testBytes) +} + +// Verify that contents of buf match the string s. +func check(t *testing.T, testname string, buf *Buffer, s string) { + bytes := buf.Bytes() + str := buf.String() + if buf.Len() != len(bytes) { + t.Errorf("%s: buf.Len() == %d, len(buf.Bytes()) == %d", testname, buf.Len(), len(bytes)) + } + + if buf.Len() != len(str) { + t.Errorf("%s: buf.Len() == %d, len(buf.String()) == %d", testname, buf.Len(), len(str)) + } + + if buf.Len() != len(s) { + t.Errorf("%s: buf.Len() == %d, len(s) == %d", testname, buf.Len(), len(s)) + } + + if string(bytes) != s { + t.Errorf("%s: string(buf.Bytes()) == %q, s == %q", testname, string(bytes), s) + } +} + +// Fill buf through n writes of byte slice fub. +// The initial contents of buf corresponds to the string s; +// the result is the final contents of buf returned as a string. +func fillBytes(t *testing.T, testname string, buf *Buffer, s string, n int, fub []byte) string { + check(t, testname+" (fill 1)", buf, s) + for ; n > 0; n-- { + m, err := buf.Write(fub) + if m != len(fub) { + t.Errorf(testname+" (fill 2): m == %d, expected %d", m, len(fub)) + } + if err != nil { + t.Errorf(testname+" (fill 3): err should always be nil, found err == %s", err) + } + s += string(fub) + check(t, testname+" (fill 4)", buf, s) + } + return s +} + +func TestNewBuffer(t *testing.T) { + buf := NewBuffer(testBytes) + check(t, "NewBuffer", buf, data) +} + +// Empty buf through repeated reads into fub. +// The initial contents of buf corresponds to the string s. 
+func empty(t *testing.T, testname string, buf *Buffer, s string, fub []byte) { + check(t, testname+" (empty 1)", buf, s) + + for { + n, err := buf.Read(fub) + if n == 0 { + break + } + if err != nil { + t.Errorf(testname+" (empty 2): err should always be nil, found err == %s", err) + } + s = s[n:] + check(t, testname+" (empty 3)", buf, s) + } + + check(t, testname+" (empty 4)", buf, "") +} + +func TestBasicOperations(t *testing.T) { + var buf Buffer + + for i := 0; i < 5; i++ { + check(t, "TestBasicOperations (1)", &buf, "") + + buf.Reset() + check(t, "TestBasicOperations (2)", &buf, "") + + buf.Truncate(0) + check(t, "TestBasicOperations (3)", &buf, "") + + n, err := buf.Write([]byte(data[0:1])) + if n != 1 { + t.Errorf("wrote 1 byte, but n == %d", n) + } + if err != nil { + t.Errorf("err should always be nil, but err == %s", err) + } + check(t, "TestBasicOperations (4)", &buf, "a") + + buf.WriteByte(data[1]) + check(t, "TestBasicOperations (5)", &buf, "ab") + + n, err = buf.Write([]byte(data[2:26])) + if n != 24 { + t.Errorf("wrote 25 bytes, but n == %d", n) + } + check(t, "TestBasicOperations (6)", &buf, string(data[0:26])) + + buf.Truncate(26) + check(t, "TestBasicOperations (7)", &buf, string(data[0:26])) + + buf.Truncate(20) + check(t, "TestBasicOperations (8)", &buf, string(data[0:20])) + + empty(t, "TestBasicOperations (9)", &buf, string(data[0:20]), make([]byte, 5)) + empty(t, "TestBasicOperations (10)", &buf, "", make([]byte, 100)) + + buf.WriteByte(data[1]) + c, err := buf.ReadByte() + if err != nil { + t.Error("ReadByte unexpected eof") + } + if c != data[1] { + t.Errorf("ReadByte wrong value c=%v", c) + } + c, err = buf.ReadByte() + if err == nil { + t.Error("ReadByte unexpected not eof") + } + } +} + +func TestLargeByteWrites(t *testing.T) { + var buf Buffer + limit := 30 + if testing.Short() { + limit = 9 + } + for i := 3; i < limit; i += 3 { + s := fillBytes(t, "TestLargeWrites (1)", &buf, "", 5, testBytes) + empty(t, "TestLargeByteWrites (2)", &buf, s, make([]byte, len(data)/i)) + } + check(t, "TestLargeByteWrites (3)", &buf, "") +} + +func TestLargeByteReads(t *testing.T) { + var buf Buffer + for i := 3; i < 30; i += 3 { + s := fillBytes(t, "TestLargeReads (1)", &buf, "", 5, testBytes[0:len(testBytes)/i]) + empty(t, "TestLargeReads (2)", &buf, s, make([]byte, len(data))) + } + check(t, "TestLargeByteReads (3)", &buf, "") +} + +func TestMixedReadsAndWrites(t *testing.T) { + var buf Buffer + s := "" + for i := 0; i < 50; i++ { + wlen := rand.Intn(len(data)) + s = fillBytes(t, "TestMixedReadsAndWrites (1)", &buf, s, 1, testBytes[0:wlen]) + rlen := rand.Intn(len(data)) + fub := make([]byte, rlen) + n, _ := buf.Read(fub) + s = s[n:] + } + empty(t, "TestMixedReadsAndWrites (2)", &buf, s, make([]byte, buf.Len())) +} + +func TestNil(t *testing.T) { + var b *Buffer + if b.String() != "" { + t.Errorf("expected ; got %q", b.String()) + } +} + +func TestReadFrom(t *testing.T) { + var buf Buffer + for i := 3; i < 30; i += 3 { + s := fillBytes(t, "TestReadFrom (1)", &buf, "", 5, testBytes[0:len(testBytes)/i]) + var b Buffer + b.ReadFrom(&buf) + empty(t, "TestReadFrom (2)", &b, s, make([]byte, len(data))) + } +} + +func TestWriteTo(t *testing.T) { + var buf Buffer + for i := 3; i < 30; i += 3 { + s := fillBytes(t, "TestWriteTo (1)", &buf, "", 5, testBytes[0:len(testBytes)/i]) + var b Buffer + buf.WriteTo(&b) + empty(t, "TestWriteTo (2)", &b, s, make([]byte, len(data))) + } +} + +func TestNext(t *testing.T) { + b := []byte{0, 1, 2, 3, 4} + tmp := make([]byte, 5) + for i := 0; i <= 5; 
i++ { + for j := i; j <= 5; j++ { + for k := 0; k <= 6; k++ { + // 0 <= i <= j <= 5; 0 <= k <= 6 + // Check that if we start with a buffer + // of length j at offset i and ask for + // Next(k), we get the right bytes. + buf := NewBuffer(b[0:j]) + n, _ := buf.Read(tmp[0:i]) + if n != i { + t.Fatalf("Read %d returned %d", i, n) + } + bb := buf.Next(k) + want := k + if want > j-i { + want = j - i + } + if len(bb) != want { + t.Fatalf("in %d,%d: len(Next(%d)) == %d", i, j, k, len(bb)) + } + for l, v := range bb { + if v != byte(l+i) { + t.Fatalf("in %d,%d: Next(%d)[%d] = %d, want %d", i, j, k, l, v, l+i) + } + } + } + } + } +} + +var readBytesTests = []struct { + buffer string + delim byte + expected []string + err error +}{ + {"", 0, []string{""}, io.EOF}, + {"a\x00", 0, []string{"a\x00"}, nil}, + {"abbbaaaba", 'b', []string{"ab", "b", "b", "aaab"}, nil}, + {"hello\x01world", 1, []string{"hello\x01"}, nil}, + {"foo\nbar", 0, []string{"foo\nbar"}, io.EOF}, + {"alpha\nbeta\ngamma\n", '\n', []string{"alpha\n", "beta\n", "gamma\n"}, nil}, + {"alpha\nbeta\ngamma", '\n', []string{"alpha\n", "beta\n", "gamma"}, io.EOF}, +} + +func TestReadBytes(t *testing.T) { + for _, test := range readBytesTests { + buf := NewBuffer([]byte(test.buffer)) + var err error + for _, expected := range test.expected { + var bytes []byte + bytes, err = buf.ReadBytes(test.delim) + if string(bytes) != expected { + t.Errorf("expected %q, got %q", expected, bytes) + } + if err != nil { + break + } + } + if err != test.err { + t.Errorf("expected error %v, got %v", test.err, err) + } + } +} + +func TestGrow(t *testing.T) { + x := []byte{'x'} + y := []byte{'y'} + tmp := make([]byte, 72) + for _, startLen := range []int{0, 100, 1000, 10000, 100000} { + xBytes := bytes.Repeat(x, startLen) + for _, growLen := range []int{0, 100, 1000, 10000, 100000} { + buf := NewBuffer(xBytes) + // If we read, this affects buf.off, which is good to test. + readBytes, _ := buf.Read(tmp) + buf.Grow(growLen) + yBytes := bytes.Repeat(y, growLen) + // Check no allocation occurs in write, as long as we're single-threaded. + var m1, m2 runtime.MemStats + runtime.ReadMemStats(&m1) + buf.Write(yBytes) + runtime.ReadMemStats(&m2) + if runtime.GOMAXPROCS(-1) == 1 && m1.Mallocs != m2.Mallocs { + t.Errorf("allocation occurred during write") + } + // Check that buffer has correct data. + if !bytes.Equal(buf.Bytes()[0:startLen-readBytes], xBytes[readBytes:]) { + t.Errorf("bad initial data at %d %d", startLen, growLen) + } + if !bytes.Equal(buf.Bytes()[startLen-readBytes:startLen-readBytes+growLen], yBytes) { + t.Errorf("bad written data at %d %d", startLen, growLen) + } + } + } +} + +// Was a bug: used to give EOF reading empty slice at EOF. +func TestReadEmptyAtEOF(t *testing.T) { + b := new(Buffer) + slice := make([]byte, 0) + n, err := b.Read(slice) + if err != nil { + t.Errorf("read error: %v", err) + } + if n != 0 { + t.Errorf("wrong count; got %d want 0", n) + } +} + +// Tests that we occasionally compact. Issue 5154. +func TestBufferGrowth(t *testing.T) { + var b Buffer + buf := make([]byte, 1024) + b.Write(buf[0:1]) + var cap0 int + for i := 0; i < 5<<10; i++ { + b.Write(buf) + b.Read(buf) + if i == 0 { + cap0 = cap(b.buf) + } + } + cap1 := cap(b.buf) + // (*Buffer).grow allows for 2x capacity slop before sliding, + // so set our error threshold at 3x. + if cap1 > cap0*3 { + t.Errorf("buffer cap = %d; too big (grew from %d)", cap1, cap0) + } +} + +// From Issue 5154. 
+func BenchmarkBufferNotEmptyWriteRead(b *testing.B) { + buf := make([]byte, 1024) + for i := 0; i < b.N; i++ { + var b Buffer + b.Write(buf[0:1]) + for i := 0; i < 5<<10; i++ { + b.Write(buf) + b.Read(buf) + } + } +} + +// Check that we don't compact too often. From Issue 5154. +func BenchmarkBufferFullSmallReads(b *testing.B) { + buf := make([]byte, 1024) + for i := 0; i < b.N; i++ { + var b Buffer + b.Write(buf) + for b.Len()+20 < cap(b.buf) { + b.Write(buf[:10]) + } + for i := 0; i < 5<<10; i++ { + b.Read(buf[:1]) + b.Write(buf[:1]) + } + } +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/crc32.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/crc32.go new file mode 100644 index 0000000000..631c9d6109 --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/crc32.go @@ -0,0 +1,30 @@ +// Copyright 2011 The LevelDB-Go Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package util + +import ( + "hash/crc32" +) + +var table = crc32.MakeTable(crc32.Castagnoli) + +// CRC is a CRC-32 checksum computed using Castagnoli's polynomial. +type CRC uint32 + +// NewCRC creates a new crc based on the given bytes. +func NewCRC(b []byte) CRC { + return CRC(0).Update(b) +} + +// Update updates the crc with the given bytes. +func (c CRC) Update(b []byte) CRC { + return CRC(crc32.Update(uint32(c), table, b)) +} + +// Value returns a masked crc. +func (c CRC) Value() uint32 { + return uint32(c>>15|c<<17) + 0xa282ead8 +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/hash.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/hash.go new file mode 100644 index 0000000000..54903660ff --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/hash.go @@ -0,0 +1,48 @@ +// Copyright (c) 2012, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package util + +import ( + "bytes" + "encoding/binary" +) + +// Hash return hash of the given data. +func Hash(data []byte, seed uint32) uint32 { + // Similar to murmur hash + var m uint32 = 0xc6a4a793 + var r uint32 = 24 + h := seed ^ (uint32(len(data)) * m) + + buf := bytes.NewBuffer(data) + for buf.Len() >= 4 { + var w uint32 + binary.Read(buf, binary.LittleEndian, &w) + h += w + h *= m + h ^= (h >> 16) + } + + rest := buf.Bytes() + switch len(rest) { + default: + panic("not reached") + case 3: + h += uint32(rest[2]) << 16 + fallthrough + case 2: + h += uint32(rest[1]) << 8 + fallthrough + case 1: + h += uint32(rest[0]) + h *= m + h ^= (h >> r) + case 0: + } + + return h +} diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/pool.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/pool.go new file mode 100644 index 0000000000..1f7fdd41fe --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/pool.go @@ -0,0 +1,21 @@ +// Copyright (c) 2014, Suryandaru Triandana +// All rights reserved. +// +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+
+// +build go1.3
+
+package util
+
+import (
+	"sync"
+)
+
+type Pool struct {
+	sync.Pool
+}
+
+func NewPool(cap int) *Pool {
+	return &Pool{}
+}
diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/pool_legacy.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/pool_legacy.go
new file mode 100644
index 0000000000..27b8d03be9
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/pool_legacy.go
@@ -0,0 +1,33 @@
+// Copyright (c) 2014, Suryandaru Triandana
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+// +build !go1.3
+
+package util
+
+type Pool struct {
+	pool chan interface{}
+}
+
+func (p *Pool) Get() interface{} {
+	select {
+	case x := <-p.pool:
+		return x
+	default:
+		return nil
+	}
+}
+
+func (p *Pool) Put(x interface{}) {
+	select {
+	case p.pool <- x:
+	default:
+	}
+}
+
+func NewPool(cap int) *Pool {
+	return &Pool{pool: make(chan interface{}, cap)}
+}
diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/range.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/range.go
new file mode 100644
index 0000000000..85159583d2
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/range.go
@@ -0,0 +1,32 @@
+// Copyright (c) 2014, Suryandaru Triandana
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+package util
+
+// Range is a key range.
+type Range struct {
+	// Start of the key range, included in the range.
+	Start []byte
+
+	// Limit of the key range, not included in the range.
+	Limit []byte
+}
+
+// BytesPrefix returns a key range that satisfies the given prefix.
+// This is only applicable to the standard 'bytes comparer'.
+func BytesPrefix(prefix []byte) *Range {
+	var limit []byte
+	for i := len(prefix) - 1; i >= 0; i-- {
+		c := prefix[i]
+		if c < 0xff {
+			limit = make([]byte, i+1)
+			copy(limit, prefix)
+			limit[i] = c + 1
+			break
+		}
+	}
+	return &Range{prefix, limit}
+}
diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/util.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/util.go
new file mode 100644
index 0000000000..f35976865b
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/util/util.go
@@ -0,0 +1,73 @@
+// Copyright (c) 2013, Suryandaru Triandana
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+// Package util provides utilities used throughout leveldb.
+package util
+
+import (
+	"errors"
+)
+
+var (
+	ErrReleased    = errors.New("leveldb: resource already released")
+	ErrHasReleaser = errors.New("leveldb: releaser already defined")
+)
+
+// Releaser is the interface that wraps the basic Release method.
+type Releaser interface {
+	// Release releases associated resources. Release should always succeed
+	// and can be called multiple times without causing an error.
+	Release()
+}
+
+// ReleaseSetter is the interface that wraps the basic SetReleaser method.
+type ReleaseSetter interface {
+	// SetReleaser associates the given releaser to the resources. The
+	// releaser will be called once the corresponding resource is released.
+	// Calling SetReleaser with nil will clear the releaser.
+	//
+	// This will panic if a releaser is already present or the
+	// corresponding resource is already released. The releaser should be
+	// cleared first before a new one is assigned.
+	SetReleaser(releaser Releaser)
+}
+
+// BasicReleaser provides a basic implementation of Releaser and ReleaseSetter.
+type BasicReleaser struct {
+	releaser Releaser
+	released bool
+}
+
+// Released reports whether the Release method has already been called.
+func (r *BasicReleaser) Released() bool {
+	return r.released
+}
+
+// Release implements Releaser.Release.
+func (r *BasicReleaser) Release() {
+	if !r.released {
+		if r.releaser != nil {
+			r.releaser.Release()
+			r.releaser = nil
+		}
+		r.released = true
+	}
+}
+
+// SetReleaser implements ReleaseSetter.SetReleaser.
+func (r *BasicReleaser) SetReleaser(releaser Releaser) {
+	if r.released {
+		panic(ErrReleased)
+	}
+	if r.releaser != nil && releaser != nil {
+		panic(ErrHasReleaser)
+	}
+	r.releaser = releaser
+}
+
+type NoopReleaser struct{}
+
+func (NoopReleaser) Release() {}
diff --git a/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/version.go b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/version.go
new file mode 100644
index 0000000000..5ab7b53d31
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb/version.go
@@ -0,0 +1,436 @@
+// Copyright (c) 2012, Suryandaru Triandana
+// All rights reserved.
+//
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+package leveldb
+
+import (
+	"sync/atomic"
+	"unsafe"
+
+	"github.com/syndtr/goleveldb/leveldb/iterator"
+	"github.com/syndtr/goleveldb/leveldb/opt"
+	"github.com/syndtr/goleveldb/leveldb/util"
+)
+
+type tSet struct {
+	level int
+	table *tFile
+}
+
+type version struct {
+	s *session
+
+	tables []tFiles
+
+	// Level that should be compacted next and its compaction score.
+	// Score < 1 means compaction is not strictly needed. These fields
+	// are initialized by computeCompaction()
+	cLevel int
+	cScore float64
+
+	cSeek unsafe.Pointer
+
+	ref int
+	// Succeeding version.
+	next *version
+}
+
+func newVersion(s *session) *version {
+	return &version{s: s, tables: make([]tFiles, s.o.GetNumLevel())}
+}
+
+func (v *version) releaseNB() {
+	v.ref--
+	if v.ref > 0 {
+		return
+	}
+	if v.ref < 0 {
+		panic("negative version ref")
+	}
+
+	tables := make(map[uint64]bool)
+	for _, tt := range v.next.tables {
+		for _, t := range tt {
+			num := t.file.Num()
+			tables[num] = true
+		}
+	}
+
+	for _, tt := range v.tables {
+		for _, t := range tt {
+			num := t.file.Num()
+			if _, ok := tables[num]; !ok {
+				v.s.tops.remove(t)
+			}
+		}
+	}
+
+	v.next.releaseNB()
+	v.next = nil
+}
+
+func (v *version) release() {
+	v.s.vmu.Lock()
+	v.releaseNB()
+	v.s.vmu.Unlock()
+}
+
+func (v *version) walkOverlapping(ikey iKey, f func(level int, t *tFile) bool, lf func(level int) bool) {
+	ukey := ikey.ukey()
+
+	// Walk tables level-by-level.
+	for level, tables := range v.tables {
+		if len(tables) == 0 {
+			continue
+		}
+
+		if level == 0 {
+			// Level-0 files may overlap each other. Find all files that
+			// overlap ukey.
+			for _, t := range tables {
+				if t.overlaps(v.s.icmp, ukey, ukey) {
+					if !f(level, t) {
+						return
+					}
+				}
+			}
+		} else {
+			if i := tables.searchMax(v.s.icmp, ikey); i < len(tables) {
+				t := tables[i]
+				if v.s.icmp.uCompare(ukey, t.imin.ukey()) >= 0 {
+					if !f(level, t) {
+						return
+					}
+				}
+			}
+		}
+
+		if lf != nil && !lf(level) {
+			return
+		}
+	}
+}
+
+func (v *version) get(ikey iKey, ro *opt.ReadOptions, noValue bool) (value []byte, tcomp bool, err error) {
+	ukey := ikey.ukey()
+
+	var (
+		tset  *tSet
+		tseek bool
+
+		// Level-0.
+		zfound bool
+		zseq   uint64
+		zkt    kType
+		zval   []byte
+	)
+
+	err = ErrNotFound
+
+	// Since entries never hop across levels, finding the key/value
+	// at a smaller level makes later levels irrelevant.
+	v.walkOverlapping(ikey, func(level int, t *tFile) bool {
+		if !tseek {
+			if tset == nil {
+				tset = &tSet{level, t}
+			} else if tset.table.consumeSeek() <= 0 {
+				tseek = true
+				tcomp = atomic.CompareAndSwapPointer(&v.cSeek, nil, unsafe.Pointer(tset))
+			}
+		}
+
+		var (
+			fikey, fval []byte
+			ferr        error
+		)
+		if noValue {
+			fikey, ferr = v.s.tops.findKey(t, ikey, ro)
+		} else {
+			fikey, fval, ferr = v.s.tops.find(t, ikey, ro)
+		}
+		switch ferr {
+		case nil:
+		case ErrNotFound:
+			return true
+		default:
+			err = ferr
+			return false
+		}
+
+		if fukey, fseq, fkt, fkerr := parseIkey(fikey); fkerr == nil {
+			if v.s.icmp.uCompare(ukey, fukey) == 0 {
+				if level == 0 {
+					if fseq >= zseq {
+						zfound = true
+						zseq = fseq
+						zkt = fkt
+						zval = fval
+					}
+				} else {
+					switch fkt {
+					case ktVal:
+						value = fval
+						err = nil
+					case ktDel:
+					default:
+						panic("leveldb: invalid iKey type")
+					}
+					return false
+				}
+			}
+		} else {
+			err = fkerr
+			return false
+		}
+
+		return true
+	}, func(level int) bool {
+		if zfound {
+			switch zkt {
+			case ktVal:
+				value = zval
+				err = nil
+			case ktDel:
+			default:
+				panic("leveldb: invalid iKey type")
+			}
+			return false
+		}
+
+		return true
+	})
+
+	return
+}
+
+func (v *version) getIterators(slice *util.Range, ro *opt.ReadOptions) (its []iterator.Iterator) {
+	// Merge all level zero files together since they may overlap
+	for _, t := range v.tables[0] {
+		it := v.s.tops.newIterator(t, slice, ro)
+		its = append(its, it)
+	}
+
+	strict := opt.GetStrict(v.s.o.Options, ro, opt.StrictReader)
+	for _, tables := range v.tables[1:] {
+		if len(tables) == 0 {
+			continue
+		}
+
+		it := iterator.NewIndexedIterator(tables.newIndexIterator(v.s.tops, v.s.icmp, slice, ro), strict)
+		its = append(its, it)
+	}
+
+	return
+}
+
+func (v *version) newStaging() *versionStaging {
+	return &versionStaging{base: v, tables: make([]tablesScratch, v.s.o.GetNumLevel())}
+}
+
+// Spawn a new version based on this version.
+func (v *version) spawn(r *sessionRecord) *version {
+	staging := v.newStaging()
+	staging.commit(r)
+	return staging.finish()
+}
+
+func (v *version) fillRecord(r *sessionRecord) {
+	for level, ts := range v.tables {
+		for _, t := range ts {
+			r.addTableFile(level, t)
+		}
+	}
+}
+
+func (v *version) tLen(level int) int {
+	return len(v.tables[level])
+}
+
+func (v *version) offsetOf(ikey iKey) (n uint64, err error) {
+	for level, tables := range v.tables {
+		for _, t := range tables {
+			if v.s.icmp.Compare(t.imax, ikey) <= 0 {
+				// Entire file is before "ikey", so just add the file size
+				n += t.size
+			} else if v.s.icmp.Compare(t.imin, ikey) > 0 {
+				// Entire file is after "ikey", so ignore
+				if level > 0 {
+					// Files other than level 0 are sorted by meta->min, so
+					// no further files in this level will contain data for
+					// "ikey".
+					break
+				}
+			} else {
+				// "ikey" falls in the range for this table. Add the
+				// approximate offset of "ikey" within the table.
+ var nn uint64 + nn, err = v.s.tops.offsetOf(t, ikey) + if err != nil { + return 0, err + } + n += nn + } + } + } + + return +} + +func (v *version) pickLevel(umin, umax []byte) (level int) { + if !v.tables[0].overlaps(v.s.icmp, umin, umax, true) { + var overlaps tFiles + maxLevel := v.s.o.GetMaxMemCompationLevel() + for ; level < maxLevel; level++ { + if v.tables[level+1].overlaps(v.s.icmp, umin, umax, false) { + break + } + overlaps = v.tables[level+2].getOverlaps(overlaps, v.s.icmp, umin, umax, false) + if overlaps.size() > uint64(v.s.o.GetCompactionGPOverlaps(level)) { + break + } + } + } + + return +} + +func (v *version) computeCompaction() { + // Precomputed best level for next compaction + var bestLevel int = -1 + var bestScore float64 = -1 + + for level, tables := range v.tables { + var score float64 + if level == 0 { + // We treat level-0 specially by bounding the number of files + // instead of number of bytes for two reasons: + // + // (1) With larger write-buffer sizes, it is nice not to do too + // many level-0 compactions. + // + // (2) The files in level-0 are merged on every read and + // therefore we wish to avoid too many files when the individual + // file size is small (perhaps because of a small write-buffer + // setting, or very high compression ratios, or lots of + // overwrites/deletions). + score = float64(len(tables)) / float64(v.s.o.GetCompactionL0Trigger()) + } else { + score = float64(tables.size()) / float64(v.s.o.GetCompactionTotalSize(level)) + } + + if score > bestScore { + bestLevel = level + bestScore = score + } + } + + v.cLevel = bestLevel + v.cScore = bestScore +} + +func (v *version) needCompaction() bool { + return v.cScore >= 1 || atomic.LoadPointer(&v.cSeek) != nil +} + +type tablesScratch struct { + added map[uint64]atRecord + deleted map[uint64]struct{} +} + +type versionStaging struct { + base *version + tables []tablesScratch +} + +func (p *versionStaging) commit(r *sessionRecord) { + // Deleted tables. + for _, r := range r.deletedTables { + tm := &(p.tables[r.level]) + + if len(p.base.tables[r.level]) > 0 { + if tm.deleted == nil { + tm.deleted = make(map[uint64]struct{}) + } + tm.deleted[r.num] = struct{}{} + } + + if tm.added != nil { + delete(tm.added, r.num) + } + } + + // New tables. + for _, r := range r.addedTables { + tm := &(p.tables[r.level]) + + if tm.added == nil { + tm.added = make(map[uint64]atRecord) + } + tm.added[r.num] = r + + if tm.deleted != nil { + delete(tm.deleted, r.num) + } + } +} + +func (p *versionStaging) finish() *version { + // Build new version. + nv := newVersion(p.base.s) + for level, tm := range p.tables { + btables := p.base.tables[level] + + n := len(btables) + len(tm.added) - len(tm.deleted) + if n < 0 { + n = 0 + } + nt := make(tFiles, 0, n) + + // Base tables. + for _, t := range btables { + if _, ok := tm.deleted[t.file.Num()]; ok { + continue + } + if _, ok := tm.added[t.file.Num()]; ok { + continue + } + nt = append(nt, t) + } + + // New tables. + for _, r := range tm.added { + nt = append(nt, p.base.s.tableFileFromRecord(r)) + } + + // Sort tables. + if level == 0 { + nt.sortByNum() + } else { + nt.sortByKey(p.base.s.icmp) + } + nv.tables[level] = nt + } + + // Compute compaction score for new version. 
+ nv.computeCompaction() + + return nv +} + +type versionReleaser struct { + v *version + once bool +} + +func (vr *versionReleaser) Release() { + v := vr.v + v.s.vmu.Lock() + if !vr.once { + v.releaseNB() + vr.once = true + } + v.s.vmu.Unlock() +} diff --git a/Godeps/_workspace/src/github.com/syndtr/gosnappy/snappy/decode.go b/Godeps/_workspace/src/github.com/syndtr/gosnappy/snappy/decode.go new file mode 100644 index 0000000000..d93c1b9dbf --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/gosnappy/snappy/decode.go @@ -0,0 +1,124 @@ +// Copyright 2011 The Snappy-Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package snappy + +import ( + "encoding/binary" + "errors" +) + +// ErrCorrupt reports that the input is invalid. +var ErrCorrupt = errors.New("snappy: corrupt input") + +// DecodedLen returns the length of the decoded block. +func DecodedLen(src []byte) (int, error) { + v, _, err := decodedLen(src) + return v, err +} + +// decodedLen returns the length of the decoded block and the number of bytes +// that the length header occupied. +func decodedLen(src []byte) (blockLen, headerLen int, err error) { + v, n := binary.Uvarint(src) + if n == 0 { + return 0, 0, ErrCorrupt + } + if uint64(int(v)) != v { + return 0, 0, errors.New("snappy: decoded block is too large") + } + return int(v), n, nil +} + +// Decode returns the decoded form of src. The returned slice may be a sub- +// slice of dst if dst was large enough to hold the entire decoded block. +// Otherwise, a newly allocated slice will be returned. +// It is valid to pass a nil dst. +func Decode(dst, src []byte) ([]byte, error) { + dLen, s, err := decodedLen(src) + if err != nil { + return nil, err + } + if len(dst) < dLen { + dst = make([]byte, dLen) + } + + var d, offset, length int + for s < len(src) { + switch src[s] & 0x03 { + case tagLiteral: + x := uint(src[s] >> 2) + switch { + case x < 60: + s += 1 + case x == 60: + s += 2 + if s > len(src) { + return nil, ErrCorrupt + } + x = uint(src[s-1]) + case x == 61: + s += 3 + if s > len(src) { + return nil, ErrCorrupt + } + x = uint(src[s-2]) | uint(src[s-1])<<8 + case x == 62: + s += 4 + if s > len(src) { + return nil, ErrCorrupt + } + x = uint(src[s-3]) | uint(src[s-2])<<8 | uint(src[s-1])<<16 + case x == 63: + s += 5 + if s > len(src) { + return nil, ErrCorrupt + } + x = uint(src[s-4]) | uint(src[s-3])<<8 | uint(src[s-2])<<16 | uint(src[s-1])<<24 + } + length = int(x + 1) + if length <= 0 { + return nil, errors.New("snappy: unsupported literal length") + } + if length > len(dst)-d || length > len(src)-s { + return nil, ErrCorrupt + } + copy(dst[d:], src[s:s+length]) + d += length + s += length + continue + + case tagCopy1: + s += 2 + if s > len(src) { + return nil, ErrCorrupt + } + length = 4 + int(src[s-2])>>2&0x7 + offset = int(src[s-2])&0xe0<<3 | int(src[s-1]) + + case tagCopy2: + s += 3 + if s > len(src) { + return nil, ErrCorrupt + } + length = 1 + int(src[s-3])>>2 + offset = int(src[s-2]) | int(src[s-1])<<8 + + case tagCopy4: + return nil, errors.New("snappy: unsupported COPY_4 tag") + } + + end := d + length + if offset > d || end > len(dst) { + return nil, ErrCorrupt + } + for ; d < end; d++ { + dst[d] = dst[d-offset] + } + } + if d != dLen { + return nil, ErrCorrupt + } + return dst[:d], nil +} diff --git a/Godeps/_workspace/src/github.com/syndtr/gosnappy/snappy/encode.go b/Godeps/_workspace/src/github.com/syndtr/gosnappy/snappy/encode.go new file mode 100644 index 
0000000000..b2371db11c
--- /dev/null
+++ b/Godeps/_workspace/src/github.com/syndtr/gosnappy/snappy/encode.go
@@ -0,0 +1,174 @@
+// Copyright 2011 The Snappy-Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package snappy
+
+import (
+	"encoding/binary"
+)
+
+// We limit how far copy back-references can go, the same as the C++ code.
+const maxOffset = 1 << 15
+
+// emitLiteral writes a literal chunk and returns the number of bytes written.
+func emitLiteral(dst, lit []byte) int {
+	i, n := 0, uint(len(lit)-1)
+	switch {
+	case n < 60:
+		dst[0] = uint8(n)<<2 | tagLiteral
+		i = 1
+	case n < 1<<8:
+		dst[0] = 60<<2 | tagLiteral
+		dst[1] = uint8(n)
+		i = 2
+	case n < 1<<16:
+		dst[0] = 61<<2 | tagLiteral
+		dst[1] = uint8(n)
+		dst[2] = uint8(n >> 8)
+		i = 3
+	case n < 1<<24:
+		dst[0] = 62<<2 | tagLiteral
+		dst[1] = uint8(n)
+		dst[2] = uint8(n >> 8)
+		dst[3] = uint8(n >> 16)
+		i = 4
+	case int64(n) < 1<<32:
+		dst[0] = 63<<2 | tagLiteral
+		dst[1] = uint8(n)
+		dst[2] = uint8(n >> 8)
+		dst[3] = uint8(n >> 16)
+		dst[4] = uint8(n >> 24)
+		i = 5
+	default:
+		panic("snappy: source buffer is too long")
+	}
+	if copy(dst[i:], lit) != len(lit) {
+		panic("snappy: destination buffer is too short")
+	}
+	return i + len(lit)
+}
+
+// emitCopy writes a copy chunk and returns the number of bytes written.
+func emitCopy(dst []byte, offset, length int) int {
+	i := 0
+	for length > 0 {
+		x := length - 4
+		if 0 <= x && x < 1<<3 && offset < 1<<11 {
+			dst[i+0] = uint8(offset>>8)&0x07<<5 | uint8(x)<<2 | tagCopy1
+			dst[i+1] = uint8(offset)
+			i += 2
+			break
+		}
+
+		x = length
+		if x > 1<<6 {
+			x = 1 << 6
+		}
+		dst[i+0] = uint8(x-1)<<2 | tagCopy2
+		dst[i+1] = uint8(offset)
+		dst[i+2] = uint8(offset >> 8)
+		i += 3
+		length -= x
+	}
+	return i
+}
+
+// Encode returns the encoded form of src. The returned slice may be a sub-
+// slice of dst if dst was large enough to hold the entire encoded block.
+// Otherwise, a newly allocated slice will be returned.
+// It is valid to pass a nil dst.
+func Encode(dst, src []byte) ([]byte, error) {
+	if n := MaxEncodedLen(len(src)); len(dst) < n {
+		dst = make([]byte, n)
+	}
+
+	// The block starts with the varint-encoded length of the decompressed bytes.
+	d := binary.PutUvarint(dst, uint64(len(src)))
+
+	// Return early if src is short.
+	if len(src) <= 4 {
+		if len(src) != 0 {
+			d += emitLiteral(dst[d:], src)
+		}
+		return dst[:d], nil
+	}
+
+	// Initialize the hash table. Its size ranges from 1<<8 to 1<<14 inclusive.
+	const maxTableSize = 1 << 14
+	shift, tableSize := uint(32-8), 1<<8
+	for tableSize < maxTableSize && tableSize < len(src) {
+		shift--
+		tableSize *= 2
+	}
+	var table [maxTableSize]int
+
+	// Iterate over the source bytes.
+	var (
+		s   int // The iterator position.
+		t   int // The last position with the same hash as s.
+		lit int // The start position of any pending literal bytes.
+	)
+	for s+3 < len(src) {
+		// Update the hash table.
+		b0, b1, b2, b3 := src[s], src[s+1], src[s+2], src[s+3]
+		h := uint32(b0) | uint32(b1)<<8 | uint32(b2)<<16 | uint32(b3)<<24
+		p := &table[(h*0x1e35a7bd)>>shift]
+		// We need to store values in [-1, inf) in table. To save
+		// some initialization time, (re)use the table's zero value
+		// and shift the values against this zero: add 1 on writes,
+		// subtract 1 on reads.
+		t, *p = *p-1, s+1
+		// If t is invalid or src[s:s+4] differs from src[t:t+4], accumulate a literal byte.
+ if t < 0 || s-t >= maxOffset || b0 != src[t] || b1 != src[t+1] || b2 != src[t+2] || b3 != src[t+3] { + s++ + continue + } + // Otherwise, we have a match. First, emit any pending literal bytes. + if lit != s { + d += emitLiteral(dst[d:], src[lit:s]) + } + // Extend the match to be as long as possible. + s0 := s + s, t = s+4, t+4 + for s < len(src) && src[s] == src[t] { + s++ + t++ + } + // Emit the copied bytes. + d += emitCopy(dst[d:], s-t, s-s0) + lit = s + } + + // Emit any final pending literal bytes and return. + if lit != len(src) { + d += emitLiteral(dst[d:], src[lit:]) + } + return dst[:d], nil +} + +// MaxEncodedLen returns the maximum length of a snappy block, given its +// uncompressed length. +func MaxEncodedLen(srcLen int) int { + // Compressed data can be defined as: + // compressed := item* literal* + // item := literal* copy + // + // The trailing literal sequence has a space blowup of at most 62/60 + // since a literal of length 60 needs one tag byte + one extra byte + // for length information. + // + // Item blowup is trickier to measure. Suppose the "copy" op copies + // 4 bytes of data. Because of a special check in the encoding code, + // we produce a 4-byte copy only if the offset is < 65536. Therefore + // the copy op takes 3 bytes to encode, and this type of item leads + // to at most the 62/60 blowup for representing literals. + // + // Suppose the "copy" op copies 5 bytes of data. If the offset is big + // enough, it will take 5 bytes to encode the copy op. Therefore the + // worst case here is a one-byte literal followed by a five-byte copy. + // That is, 6 bytes of input turn into 7 bytes of "compressed" data. + // + // This last factor dominates the blowup, so the final estimate is: + return 32 + srcLen + srcLen/6 +} diff --git a/Godeps/_workspace/src/github.com/syndtr/gosnappy/snappy/snappy.go b/Godeps/_workspace/src/github.com/syndtr/gosnappy/snappy/snappy.go new file mode 100644 index 0000000000..2f1b790d0b --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/gosnappy/snappy/snappy.go @@ -0,0 +1,38 @@ +// Copyright 2011 The Snappy-Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package snappy implements the snappy block-based compression format. +// It aims for very high speeds and reasonable compression. +// +// The C++ snappy implementation is at http://code.google.com/p/snappy/ +package snappy + +/* +Each encoded block begins with the varint-encoded length of the decoded data, +followed by a sequence of chunks. Chunks begin and end on byte boundaries. The +first byte of each chunk is broken into its 2 least and 6 most significant bits +called l and m: l ranges in [0, 4) and m ranges in [0, 64). l is the chunk tag. +Zero means a literal tag. All other values mean a copy tag. + +For literal tags: + - If m < 60, the next 1 + m bytes are literal bytes. + - Otherwise, let n be the little-endian unsigned integer denoted by the next + m - 59 bytes. The next 1 + n bytes after that are literal bytes. + +For copy tags, length bytes are copied from offset bytes ago, in the style of +Lempel-Ziv compression algorithms. In particular: + - For l == 1, the offset ranges in [0, 1<<11) and the length in [4, 12). + The length is 4 + the low 3 bits of m. The high 3 bits of m form bits 8-10 + of the offset. The next byte is bits 0-7 of the offset. + - For l == 2, the offset ranges in [0, 1<<16) and the length in [1, 65). + The length is 1 + m. 
The offset is the little-endian unsigned integer + denoted by the next 2 bytes. + - For l == 3, this tag is a legacy format that is no longer supported. +*/ +const ( + tagLiteral = 0x00 + tagCopy1 = 0x01 + tagCopy2 = 0x02 + tagCopy4 = 0x03 +) diff --git a/Godeps/_workspace/src/github.com/syndtr/gosnappy/snappy/snappy_test.go b/Godeps/_workspace/src/github.com/syndtr/gosnappy/snappy/snappy_test.go new file mode 100644 index 0000000000..7ba839244e --- /dev/null +++ b/Godeps/_workspace/src/github.com/syndtr/gosnappy/snappy/snappy_test.go @@ -0,0 +1,261 @@ +// Copyright 2011 The Snappy-Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package snappy + +import ( + "bytes" + "flag" + "fmt" + "io" + "io/ioutil" + "math/rand" + "net/http" + "os" + "path/filepath" + "strings" + "testing" +) + +var download = flag.Bool("download", false, "If true, download any missing files before running benchmarks") + +func roundtrip(b, ebuf, dbuf []byte) error { + e, err := Encode(ebuf, b) + if err != nil { + return fmt.Errorf("encoding error: %v", err) + } + d, err := Decode(dbuf, e) + if err != nil { + return fmt.Errorf("decoding error: %v", err) + } + if !bytes.Equal(b, d) { + return fmt.Errorf("roundtrip mismatch:\n\twant %v\n\tgot %v", b, d) + } + return nil +} + +func TestEmpty(t *testing.T) { + if err := roundtrip(nil, nil, nil); err != nil { + t.Fatal(err) + } +} + +func TestSmallCopy(t *testing.T) { + for _, ebuf := range [][]byte{nil, make([]byte, 20), make([]byte, 64)} { + for _, dbuf := range [][]byte{nil, make([]byte, 20), make([]byte, 64)} { + for i := 0; i < 32; i++ { + s := "aaaa" + strings.Repeat("b", i) + "aaaabbbb" + if err := roundtrip([]byte(s), ebuf, dbuf); err != nil { + t.Errorf("len(ebuf)=%d, len(dbuf)=%d, i=%d: %v", len(ebuf), len(dbuf), i, err) + } + } + } + } +} + +func TestSmallRand(t *testing.T) { + rand.Seed(27354294) + for n := 1; n < 20000; n += 23 { + b := make([]byte, n) + for i, _ := range b { + b[i] = uint8(rand.Uint32()) + } + if err := roundtrip(b, nil, nil); err != nil { + t.Fatal(err) + } + } +} + +func TestSmallRegular(t *testing.T) { + for n := 1; n < 20000; n += 23 { + b := make([]byte, n) + for i, _ := range b { + b[i] = uint8(i%10 + 'a') + } + if err := roundtrip(b, nil, nil); err != nil { + t.Fatal(err) + } + } +} + +func benchDecode(b *testing.B, src []byte) { + encoded, err := Encode(nil, src) + if err != nil { + b.Fatal(err) + } + // Bandwidth is in amount of uncompressed data. + b.SetBytes(int64(len(src))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + Decode(src, encoded) + } +} + +func benchEncode(b *testing.B, src []byte) { + // Bandwidth is in amount of uncompressed data. + b.SetBytes(int64(len(src))) + dst := make([]byte, MaxEncodedLen(len(src))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + Encode(dst, src) + } +} + +func readFile(b *testing.B, filename string) []byte { + src, err := ioutil.ReadFile(filename) + if err != nil { + b.Fatalf("failed reading %s: %s", filename, err) + } + if len(src) == 0 { + b.Fatalf("%s has zero length", filename) + } + return src +} + +// expand returns a slice of length n containing repeated copies of src. 
+func expand(src []byte, n int) []byte { + dst := make([]byte, n) + for x := dst; len(x) > 0; { + i := copy(x, src) + x = x[i:] + } + return dst +} + +func benchWords(b *testing.B, n int, decode bool) { + // Note: the file is OS-language dependent so the resulting values are not + // directly comparable for non-US-English OS installations. + data := expand(readFile(b, "/usr/share/dict/words"), n) + if decode { + benchDecode(b, data) + } else { + benchEncode(b, data) + } +} + +func BenchmarkWordsDecode1e3(b *testing.B) { benchWords(b, 1e3, true) } +func BenchmarkWordsDecode1e4(b *testing.B) { benchWords(b, 1e4, true) } +func BenchmarkWordsDecode1e5(b *testing.B) { benchWords(b, 1e5, true) } +func BenchmarkWordsDecode1e6(b *testing.B) { benchWords(b, 1e6, true) } +func BenchmarkWordsEncode1e3(b *testing.B) { benchWords(b, 1e3, false) } +func BenchmarkWordsEncode1e4(b *testing.B) { benchWords(b, 1e4, false) } +func BenchmarkWordsEncode1e5(b *testing.B) { benchWords(b, 1e5, false) } +func BenchmarkWordsEncode1e6(b *testing.B) { benchWords(b, 1e6, false) } + +// testFiles' values are copied directly from +// https://code.google.com/p/snappy/source/browse/trunk/snappy_unittest.cc. +// The label field is unused in snappy-go. +var testFiles = []struct { + label string + filename string +}{ + {"html", "html"}, + {"urls", "urls.10K"}, + {"jpg", "house.jpg"}, + {"pdf", "mapreduce-osdi-1.pdf"}, + {"html4", "html_x_4"}, + {"cp", "cp.html"}, + {"c", "fields.c"}, + {"lsp", "grammar.lsp"}, + {"xls", "kennedy.xls"}, + {"txt1", "alice29.txt"}, + {"txt2", "asyoulik.txt"}, + {"txt3", "lcet10.txt"}, + {"txt4", "plrabn12.txt"}, + {"bin", "ptt5"}, + {"sum", "sum"}, + {"man", "xargs.1"}, + {"pb", "geo.protodata"}, + {"gaviota", "kppkn.gtb"}, +} + +// The test data files are present at this canonical URL. +const baseURL = "https://snappy.googlecode.com/svn/trunk/testdata/" + +func downloadTestdata(basename string) (errRet error) { + filename := filepath.Join("testdata", basename) + f, err := os.Create(filename) + if err != nil { + return fmt.Errorf("failed to create %s: %s", filename, err) + } + defer f.Close() + defer func() { + if errRet != nil { + os.Remove(filename) + } + }() + resp, err := http.Get(baseURL + basename) + if err != nil { + return fmt.Errorf("failed to download %s: %s", baseURL+basename, err) + } + defer resp.Body.Close() + _, err = io.Copy(f, resp.Body) + if err != nil { + return fmt.Errorf("failed to write %s: %s", filename, err) + } + return nil +} + +func benchFile(b *testing.B, n int, decode bool) { + filename := filepath.Join("testdata", testFiles[n].filename) + if stat, err := os.Stat(filename); err != nil || stat.Size() == 0 { + if !*download { + b.Fatal("test data not found; skipping benchmark without the -download flag") + } + // Download the official snappy C++ implementation reference test data + // files for benchmarking. + if err := os.Mkdir("testdata", 0777); err != nil && !os.IsExist(err) { + b.Fatalf("failed to create testdata: %s", err) + } + for _, tf := range testFiles { + if err := downloadTestdata(tf.filename); err != nil { + b.Fatalf("failed to download testdata: %s", err) + } + } + } + data := readFile(b, filename) + if decode { + benchDecode(b, data) + } else { + benchEncode(b, data) + } +} + +// Naming convention is kept similar to what snappy's C++ implementation uses. 
+func Benchmark_UFlat0(b *testing.B) { benchFile(b, 0, true) } +func Benchmark_UFlat1(b *testing.B) { benchFile(b, 1, true) } +func Benchmark_UFlat2(b *testing.B) { benchFile(b, 2, true) } +func Benchmark_UFlat3(b *testing.B) { benchFile(b, 3, true) } +func Benchmark_UFlat4(b *testing.B) { benchFile(b, 4, true) } +func Benchmark_UFlat5(b *testing.B) { benchFile(b, 5, true) } +func Benchmark_UFlat6(b *testing.B) { benchFile(b, 6, true) } +func Benchmark_UFlat7(b *testing.B) { benchFile(b, 7, true) } +func Benchmark_UFlat8(b *testing.B) { benchFile(b, 8, true) } +func Benchmark_UFlat9(b *testing.B) { benchFile(b, 9, true) } +func Benchmark_UFlat10(b *testing.B) { benchFile(b, 10, true) } +func Benchmark_UFlat11(b *testing.B) { benchFile(b, 11, true) } +func Benchmark_UFlat12(b *testing.B) { benchFile(b, 12, true) } +func Benchmark_UFlat13(b *testing.B) { benchFile(b, 13, true) } +func Benchmark_UFlat14(b *testing.B) { benchFile(b, 14, true) } +func Benchmark_UFlat15(b *testing.B) { benchFile(b, 15, true) } +func Benchmark_UFlat16(b *testing.B) { benchFile(b, 16, true) } +func Benchmark_UFlat17(b *testing.B) { benchFile(b, 17, true) } +func Benchmark_ZFlat0(b *testing.B) { benchFile(b, 0, false) } +func Benchmark_ZFlat1(b *testing.B) { benchFile(b, 1, false) } +func Benchmark_ZFlat2(b *testing.B) { benchFile(b, 2, false) } +func Benchmark_ZFlat3(b *testing.B) { benchFile(b, 3, false) } +func Benchmark_ZFlat4(b *testing.B) { benchFile(b, 4, false) } +func Benchmark_ZFlat5(b *testing.B) { benchFile(b, 5, false) } +func Benchmark_ZFlat6(b *testing.B) { benchFile(b, 6, false) } +func Benchmark_ZFlat7(b *testing.B) { benchFile(b, 7, false) } +func Benchmark_ZFlat8(b *testing.B) { benchFile(b, 8, false) } +func Benchmark_ZFlat9(b *testing.B) { benchFile(b, 9, false) } +func Benchmark_ZFlat10(b *testing.B) { benchFile(b, 10, false) } +func Benchmark_ZFlat11(b *testing.B) { benchFile(b, 11, false) } +func Benchmark_ZFlat12(b *testing.B) { benchFile(b, 12, false) } +func Benchmark_ZFlat13(b *testing.B) { benchFile(b, 13, false) } +func Benchmark_ZFlat14(b *testing.B) { benchFile(b, 14, false) } +func Benchmark_ZFlat15(b *testing.B) { benchFile(b, 15, false) } +func Benchmark_ZFlat16(b *testing.B) { benchFile(b, 16, false) } +func Benchmark_ZFlat17(b *testing.B) { benchFile(b, 17, false) } diff --git a/Makefile b/Makefile index 930add4ee1..bbcc0e0631 100644 --- a/Makefile +++ b/Makefile @@ -66,6 +66,8 @@ config: dependencies $(MAKE) -C config dependencies: $(GOCC) $(FULL_GOPATH) + $(GO) get github.com/tools/godep + $(GODEP) restore $(GO) get -d documentation: search_index diff --git a/Makefile.INCLUDE b/Makefile.INCLUDE index 20dcfe0b09..ce43728503 100644 --- a/Makefile.INCLUDE +++ b/Makefile.INCLUDE @@ -41,6 +41,7 @@ GOURL ?= http://golang.org/dl GOROOT = $(BUILD_PATH)/root/go GOPATH = $(BUILD_PATH)/root/gopath GOCC = $(GOROOT)/bin/go +GODEP = $(GOPATH)/bin/godep TMPDIR = /tmp GOENV = TMPDIR=$(TMPDIR) GOROOT=$(GOROOT) GOPATH=$(GOPATH) GO = $(GOENV) $(GOCC)
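
A minimal sketch of the vendored packages in use, for anyone trying out this change. It assumes `godep restore` (wired into the `dependencies` target above) has populated the Godeps workspace on the GOPATH, so the canonical import paths resolve to the vendored copies; the `main` package and sample payload are illustrative only.

	package main

	import (
		"bytes"
		"fmt"
		"log"

		"github.com/syndtr/goleveldb/leveldb/util"
		"github.com/syndtr/gosnappy/snappy"
	)

	func main() {
		src := []byte("some repetitive payload, some repetitive payload")

		// Compress into a freshly allocated buffer (passing a nil dst is valid).
		enc, err := snappy.Encode(nil, src)
		if err != nil {
			log.Fatal(err)
		}

		// Masked CRC-32 (Castagnoli) over the compressed block, as provided
		// by the vendored leveldb/util package.
		sum := util.NewCRC(enc).Value()

		// Decompress and verify the round trip.
		dec, err := snappy.Decode(nil, enc)
		if err != nil {
			log.Fatal(err)
		}
		fmt.Printf("%d -> %d bytes, checksum %#08x, roundtrip ok: %v\n",
			len(src), len(enc), sum, bytes.Equal(dec, src))

		// Key-range helper from leveldb/util: all keys starting with "k".
		r := util.BytesPrefix([]byte("k"))
		fmt.Printf("prefix range: [%q, %q)\n", r.Start, r.Limit)
	}

As with the vendored sources above, Decode allocates only when the supplied dst is too small, so callers that reuse buffers avoid per-call allocations.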
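
Likewise, a sketch of the leveldb/util.BufferPool vendored above; the 32 KiB baseline is an arbitrary value chosen for illustration, not a recommendation.

	package main

	import (
		"fmt"

		"github.com/syndtr/goleveldb/leveldb/util"
	)

	func main() {
		// A pool with a 32 KiB baseline; buffers are bucketed around
		// baseline/4, baseline/2, baseline, baseline*2 and baseline*4.
		pool := util.NewBufferPool(32 * 1024)
		defer pool.Close()

		b := pool.Get(4096) // len(b) == 4096, possibly a recycled buffer
		copy(b, "scratch data")
		pool.Put(b) // return the buffer for reuse

		fmt.Println(pool.String()) // pool counters: gets, puts, misses, ...
	}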