extensions_test.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536
  1. // Go support for Protocol Buffers - Google's data interchange format
  2. //
  3. // Copyright 2014 The Go Authors. All rights reserved.
  4. // https://github.com/golang/protobuf
  5. //
  6. // Redistribution and use in source and binary forms, with or without
  7. // modification, are permitted provided that the following conditions are
  8. // met:
  9. //
  10. // * Redistributions of source code must retain the above copyright
  11. // notice, this list of conditions and the following disclaimer.
  12. // * Redistributions in binary form must reproduce the above
  13. // copyright notice, this list of conditions and the following disclaimer
  14. // in the documentation and/or other materials provided with the
  15. // distribution.
  16. // * Neither the name of Google Inc. nor the names of its
  17. // contributors may be used to endorse or promote products derived from
  18. // this software without specific prior written permission.
  19. //
  20. // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
  21. // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
  22. // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  23. // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  24. // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  25. // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
  26. // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  27. // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  28. // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  29. // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  30. // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  31. package proto_test
  32. import (
  33. "bytes"
  34. "fmt"
  35. "reflect"
  36. "sort"
  37. "testing"
  38. "github.com/golang/protobuf/proto"
  39. pb "github.com/golang/protobuf/proto/testdata"
  40. "golang.org/x/sync/errgroup"
  41. )
  42. func TestGetExtensionsWithMissingExtensions(t *testing.T) {
  43. msg := &pb.MyMessage{}
  44. ext1 := &pb.Ext{}
  45. if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
  46. t.Fatalf("Could not set ext1: %s", err)
  47. }
  48. exts, err := proto.GetExtensions(msg, []*proto.ExtensionDesc{
  49. pb.E_Ext_More,
  50. pb.E_Ext_Text,
  51. })
  52. if err != nil {
  53. t.Fatalf("GetExtensions() failed: %s", err)
  54. }
  55. if exts[0] != ext1 {
  56. t.Errorf("ext1 not in returned extensions: %T %v", exts[0], exts[0])
  57. }
  58. if exts[1] != nil {
  59. t.Errorf("ext2 in returned extensions: %T %v", exts[1], exts[1])
  60. }
  61. }
  62. func TestExtensionDescsWithMissingExtensions(t *testing.T) {
  63. msg := &pb.MyMessage{Count: proto.Int32(0)}
  64. extdesc1 := pb.E_Ext_More
  65. if descs, err := proto.ExtensionDescs(msg); len(descs) != 0 || err != nil {
  66. t.Errorf("proto.ExtensionDescs: got %d descs, error %v; want 0, nil", len(descs), err)
  67. }
  68. ext1 := &pb.Ext{}
  69. if err := proto.SetExtension(msg, extdesc1, ext1); err != nil {
  70. t.Fatalf("Could not set ext1: %s", err)
  71. }
  72. extdesc2 := &proto.ExtensionDesc{
  73. ExtendedType: (*pb.MyMessage)(nil),
  74. ExtensionType: (*bool)(nil),
  75. Field: 123456789,
  76. Name: "a.b",
  77. Tag: "varint,123456789,opt",
  78. }
  79. ext2 := proto.Bool(false)
  80. if err := proto.SetExtension(msg, extdesc2, ext2); err != nil {
  81. t.Fatalf("Could not set ext2: %s", err)
  82. }
  83. b, err := proto.Marshal(msg)
  84. if err != nil {
  85. t.Fatalf("Could not marshal msg: %v", err)
  86. }
  87. if err := proto.Unmarshal(b, msg); err != nil {
  88. t.Fatalf("Could not unmarshal into msg: %v", err)
  89. }
  90. descs, err := proto.ExtensionDescs(msg)
  91. if err != nil {
  92. t.Fatalf("proto.ExtensionDescs: got error %v", err)
  93. }
  94. sortExtDescs(descs)
  95. wantDescs := []*proto.ExtensionDesc{extdesc1, &proto.ExtensionDesc{Field: extdesc2.Field}}
  96. if !reflect.DeepEqual(descs, wantDescs) {
  97. t.Errorf("proto.ExtensionDescs(msg) sorted extension ids: got %+v, want %+v", descs, wantDescs)
  98. }
  99. }
  100. type ExtensionDescSlice []*proto.ExtensionDesc
  101. func (s ExtensionDescSlice) Len() int { return len(s) }
  102. func (s ExtensionDescSlice) Less(i, j int) bool { return s[i].Field < s[j].Field }
  103. func (s ExtensionDescSlice) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
  104. func sortExtDescs(s []*proto.ExtensionDesc) {
  105. sort.Sort(ExtensionDescSlice(s))
  106. }
  107. func TestGetExtensionStability(t *testing.T) {
  108. check := func(m *pb.MyMessage) bool {
  109. ext1, err := proto.GetExtension(m, pb.E_Ext_More)
  110. if err != nil {
  111. t.Fatalf("GetExtension() failed: %s", err)
  112. }
  113. ext2, err := proto.GetExtension(m, pb.E_Ext_More)
  114. if err != nil {
  115. t.Fatalf("GetExtension() failed: %s", err)
  116. }
  117. return ext1 == ext2
  118. }
  119. msg := &pb.MyMessage{Count: proto.Int32(4)}
  120. ext0 := &pb.Ext{}
  121. if err := proto.SetExtension(msg, pb.E_Ext_More, ext0); err != nil {
  122. t.Fatalf("Could not set ext1: %s", ext0)
  123. }
  124. if !check(msg) {
  125. t.Errorf("GetExtension() not stable before marshaling")
  126. }
  127. bb, err := proto.Marshal(msg)
  128. if err != nil {
  129. t.Fatalf("Marshal() failed: %s", err)
  130. }
  131. msg1 := &pb.MyMessage{}
  132. err = proto.Unmarshal(bb, msg1)
  133. if err != nil {
  134. t.Fatalf("Unmarshal() failed: %s", err)
  135. }
  136. if !check(msg1) {
  137. t.Errorf("GetExtension() not stable after unmarshaling")
  138. }
  139. }
  140. func TestGetExtensionDefaults(t *testing.T) {
  141. var setFloat64 float64 = 1
  142. var setFloat32 float32 = 2
  143. var setInt32 int32 = 3
  144. var setInt64 int64 = 4
  145. var setUint32 uint32 = 5
  146. var setUint64 uint64 = 6
  147. var setBool = true
  148. var setBool2 = false
  149. var setString = "Goodnight string"
  150. var setBytes = []byte("Goodnight bytes")
  151. var setEnum = pb.DefaultsMessage_TWO
  152. type testcase struct {
  153. ext *proto.ExtensionDesc // Extension we are testing.
  154. want interface{} // Expected value of extension, or nil (meaning that GetExtension will fail).
  155. def interface{} // Expected value of extension after ClearExtension().
  156. }
  157. tests := []testcase{
  158. {pb.E_NoDefaultDouble, setFloat64, nil},
  159. {pb.E_NoDefaultFloat, setFloat32, nil},
  160. {pb.E_NoDefaultInt32, setInt32, nil},
  161. {pb.E_NoDefaultInt64, setInt64, nil},
  162. {pb.E_NoDefaultUint32, setUint32, nil},
  163. {pb.E_NoDefaultUint64, setUint64, nil},
  164. {pb.E_NoDefaultSint32, setInt32, nil},
  165. {pb.E_NoDefaultSint64, setInt64, nil},
  166. {pb.E_NoDefaultFixed32, setUint32, nil},
  167. {pb.E_NoDefaultFixed64, setUint64, nil},
  168. {pb.E_NoDefaultSfixed32, setInt32, nil},
  169. {pb.E_NoDefaultSfixed64, setInt64, nil},
  170. {pb.E_NoDefaultBool, setBool, nil},
  171. {pb.E_NoDefaultBool, setBool2, nil},
  172. {pb.E_NoDefaultString, setString, nil},
  173. {pb.E_NoDefaultBytes, setBytes, nil},
  174. {pb.E_NoDefaultEnum, setEnum, nil},
  175. {pb.E_DefaultDouble, setFloat64, float64(3.1415)},
  176. {pb.E_DefaultFloat, setFloat32, float32(3.14)},
  177. {pb.E_DefaultInt32, setInt32, int32(42)},
  178. {pb.E_DefaultInt64, setInt64, int64(43)},
  179. {pb.E_DefaultUint32, setUint32, uint32(44)},
  180. {pb.E_DefaultUint64, setUint64, uint64(45)},
  181. {pb.E_DefaultSint32, setInt32, int32(46)},
  182. {pb.E_DefaultSint64, setInt64, int64(47)},
  183. {pb.E_DefaultFixed32, setUint32, uint32(48)},
  184. {pb.E_DefaultFixed64, setUint64, uint64(49)},
  185. {pb.E_DefaultSfixed32, setInt32, int32(50)},
  186. {pb.E_DefaultSfixed64, setInt64, int64(51)},
  187. {pb.E_DefaultBool, setBool, true},
  188. {pb.E_DefaultBool, setBool2, true},
  189. {pb.E_DefaultString, setString, "Hello, string"},
  190. {pb.E_DefaultBytes, setBytes, []byte("Hello, bytes")},
  191. {pb.E_DefaultEnum, setEnum, pb.DefaultsMessage_ONE},
  192. }
  193. checkVal := func(test testcase, msg *pb.DefaultsMessage, valWant interface{}) error {
  194. val, err := proto.GetExtension(msg, test.ext)
  195. if err != nil {
  196. if valWant != nil {
  197. return fmt.Errorf("GetExtension(): %s", err)
  198. }
  199. if want := proto.ErrMissingExtension; err != want {
  200. return fmt.Errorf("Unexpected error: got %v, want %v", err, want)
  201. }
  202. return nil
  203. }
  204. // All proto2 extension values are either a pointer to a value or a slice of values.
  205. ty := reflect.TypeOf(val)
  206. tyWant := reflect.TypeOf(test.ext.ExtensionType)
  207. if got, want := ty, tyWant; got != want {
  208. return fmt.Errorf("unexpected reflect.TypeOf(): got %v want %v", got, want)
  209. }
  210. tye := ty.Elem()
  211. tyeWant := tyWant.Elem()
  212. if got, want := tye, tyeWant; got != want {
  213. return fmt.Errorf("unexpected reflect.TypeOf().Elem(): got %v want %v", got, want)
  214. }
  215. // Check the name of the type of the value.
  216. // If it is an enum it will be type int32 with the name of the enum.
  217. if got, want := tye.Name(), tye.Name(); got != want {
  218. return fmt.Errorf("unexpected reflect.TypeOf().Elem().Name(): got %v want %v", got, want)
  219. }
  220. // Check that value is what we expect.
  221. // If we have a pointer in val, get the value it points to.
  222. valExp := val
  223. if ty.Kind() == reflect.Ptr {
  224. valExp = reflect.ValueOf(val).Elem().Interface()
  225. }
  226. if got, want := valExp, valWant; !reflect.DeepEqual(got, want) {
  227. return fmt.Errorf("unexpected reflect.DeepEqual(): got %v want %v", got, want)
  228. }
  229. return nil
  230. }
  231. setTo := func(test testcase) interface{} {
  232. setTo := reflect.ValueOf(test.want)
  233. if typ := reflect.TypeOf(test.ext.ExtensionType); typ.Kind() == reflect.Ptr {
  234. setTo = reflect.New(typ).Elem()
  235. setTo.Set(reflect.New(setTo.Type().Elem()))
  236. setTo.Elem().Set(reflect.ValueOf(test.want))
  237. }
  238. return setTo.Interface()
  239. }
  240. for _, test := range tests {
  241. msg := &pb.DefaultsMessage{}
  242. name := test.ext.Name
  243. // Check the initial value.
  244. if err := checkVal(test, msg, test.def); err != nil {
  245. t.Errorf("%s: %v", name, err)
  246. }
  247. // Set the per-type value and check value.
  248. name = fmt.Sprintf("%s (set to %T %v)", name, test.want, test.want)
  249. if err := proto.SetExtension(msg, test.ext, setTo(test)); err != nil {
  250. t.Errorf("%s: SetExtension(): %v", name, err)
  251. continue
  252. }
  253. if err := checkVal(test, msg, test.want); err != nil {
  254. t.Errorf("%s: %v", name, err)
  255. continue
  256. }
  257. // Set and check the value.
  258. name += " (cleared)"
  259. proto.ClearExtension(msg, test.ext)
  260. if err := checkVal(test, msg, test.def); err != nil {
  261. t.Errorf("%s: %v", name, err)
  262. }
  263. }
  264. }
  265. func TestExtensionsRoundTrip(t *testing.T) {
  266. msg := &pb.MyMessage{}
  267. ext1 := &pb.Ext{
  268. Data: proto.String("hi"),
  269. }
  270. ext2 := &pb.Ext{
  271. Data: proto.String("there"),
  272. }
  273. exists := proto.HasExtension(msg, pb.E_Ext_More)
  274. if exists {
  275. t.Error("Extension More present unexpectedly")
  276. }
  277. if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
  278. t.Error(err)
  279. }
  280. if err := proto.SetExtension(msg, pb.E_Ext_More, ext2); err != nil {
  281. t.Error(err)
  282. }
  283. e, err := proto.GetExtension(msg, pb.E_Ext_More)
  284. if err != nil {
  285. t.Error(err)
  286. }
  287. x, ok := e.(*pb.Ext)
  288. if !ok {
  289. t.Errorf("e has type %T, expected testdata.Ext", e)
  290. } else if *x.Data != "there" {
  291. t.Errorf("SetExtension failed to overwrite, got %+v, not 'there'", x)
  292. }
  293. proto.ClearExtension(msg, pb.E_Ext_More)
  294. if _, err = proto.GetExtension(msg, pb.E_Ext_More); err != proto.ErrMissingExtension {
  295. t.Errorf("got %v, expected ErrMissingExtension", e)
  296. }
  297. if _, err := proto.GetExtension(msg, pb.E_X215); err == nil {
  298. t.Error("expected bad extension error, got nil")
  299. }
  300. if err := proto.SetExtension(msg, pb.E_X215, 12); err == nil {
  301. t.Error("expected extension err")
  302. }
  303. if err := proto.SetExtension(msg, pb.E_Ext_More, 12); err == nil {
  304. t.Error("expected some sort of type mismatch error, got nil")
  305. }
  306. }
  307. func TestNilExtension(t *testing.T) {
  308. msg := &pb.MyMessage{
  309. Count: proto.Int32(1),
  310. }
  311. if err := proto.SetExtension(msg, pb.E_Ext_Text, proto.String("hello")); err != nil {
  312. t.Fatal(err)
  313. }
  314. if err := proto.SetExtension(msg, pb.E_Ext_More, (*pb.Ext)(nil)); err == nil {
  315. t.Error("expected SetExtension to fail due to a nil extension")
  316. } else if want := "proto: SetExtension called with nil value of type *testdata.Ext"; err.Error() != want {
  317. t.Errorf("expected error %v, got %v", want, err)
  318. }
  319. // Note: if the behavior of Marshal is ever changed to ignore nil extensions, update
  320. // this test to verify that E_Ext_Text is properly propagated through marshal->unmarshal.
  321. }
  322. func TestMarshalUnmarshalRepeatedExtension(t *testing.T) {
  323. // Add a repeated extension to the result.
  324. tests := []struct {
  325. name string
  326. ext []*pb.ComplexExtension
  327. }{
  328. {
  329. "two fields",
  330. []*pb.ComplexExtension{
  331. {First: proto.Int32(7)},
  332. {Second: proto.Int32(11)},
  333. },
  334. },
  335. {
  336. "repeated field",
  337. []*pb.ComplexExtension{
  338. {Third: []int32{1000}},
  339. {Third: []int32{2000}},
  340. },
  341. },
  342. {
  343. "two fields and repeated field",
  344. []*pb.ComplexExtension{
  345. {Third: []int32{1000}},
  346. {First: proto.Int32(9)},
  347. {Second: proto.Int32(21)},
  348. {Third: []int32{2000}},
  349. },
  350. },
  351. }
  352. for _, test := range tests {
  353. // Marshal message with a repeated extension.
  354. msg1 := new(pb.OtherMessage)
  355. err := proto.SetExtension(msg1, pb.E_RComplex, test.ext)
  356. if err != nil {
  357. t.Fatalf("[%s] Error setting extension: %v", test.name, err)
  358. }
  359. b, err := proto.Marshal(msg1)
  360. if err != nil {
  361. t.Fatalf("[%s] Error marshaling message: %v", test.name, err)
  362. }
  363. // Unmarshal and read the merged proto.
  364. msg2 := new(pb.OtherMessage)
  365. err = proto.Unmarshal(b, msg2)
  366. if err != nil {
  367. t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
  368. }
  369. e, err := proto.GetExtension(msg2, pb.E_RComplex)
  370. if err != nil {
  371. t.Fatalf("[%s] Error getting extension: %v", test.name, err)
  372. }
  373. ext := e.([]*pb.ComplexExtension)
  374. if ext == nil {
  375. t.Fatalf("[%s] Invalid extension", test.name)
  376. }
  377. if !reflect.DeepEqual(ext, test.ext) {
  378. t.Errorf("[%s] Wrong value for ComplexExtension: got: %v want: %v\n", test.name, ext, test.ext)
  379. }
  380. }
  381. }
  382. func TestUnmarshalRepeatingNonRepeatedExtension(t *testing.T) {
  383. // We may see multiple instances of the same extension in the wire
  384. // format. For example, the proto compiler may encode custom options in
  385. // this way. Here, we verify that we merge the extensions together.
  386. tests := []struct {
  387. name string
  388. ext []*pb.ComplexExtension
  389. }{
  390. {
  391. "two fields",
  392. []*pb.ComplexExtension{
  393. {First: proto.Int32(7)},
  394. {Second: proto.Int32(11)},
  395. },
  396. },
  397. {
  398. "repeated field",
  399. []*pb.ComplexExtension{
  400. {Third: []int32{1000}},
  401. {Third: []int32{2000}},
  402. },
  403. },
  404. {
  405. "two fields and repeated field",
  406. []*pb.ComplexExtension{
  407. {Third: []int32{1000}},
  408. {First: proto.Int32(9)},
  409. {Second: proto.Int32(21)},
  410. {Third: []int32{2000}},
  411. },
  412. },
  413. }
  414. for _, test := range tests {
  415. var buf bytes.Buffer
  416. var want pb.ComplexExtension
  417. // Generate a serialized representation of a repeated extension
  418. // by catenating bytes together.
  419. for i, e := range test.ext {
  420. // Merge to create the wanted proto.
  421. proto.Merge(&want, e)
  422. // serialize the message
  423. msg := new(pb.OtherMessage)
  424. err := proto.SetExtension(msg, pb.E_Complex, e)
  425. if err != nil {
  426. t.Fatalf("[%s] Error setting extension %d: %v", test.name, i, err)
  427. }
  428. b, err := proto.Marshal(msg)
  429. if err != nil {
  430. t.Fatalf("[%s] Error marshaling message %d: %v", test.name, i, err)
  431. }
  432. buf.Write(b)
  433. }
  434. // Unmarshal and read the merged proto.
  435. msg2 := new(pb.OtherMessage)
  436. err := proto.Unmarshal(buf.Bytes(), msg2)
  437. if err != nil {
  438. t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
  439. }
  440. e, err := proto.GetExtension(msg2, pb.E_Complex)
  441. if err != nil {
  442. t.Fatalf("[%s] Error getting extension: %v", test.name, err)
  443. }
  444. ext := e.(*pb.ComplexExtension)
  445. if ext == nil {
  446. t.Fatalf("[%s] Invalid extension", test.name)
  447. }
  448. if !reflect.DeepEqual(*ext, want) {
  449. t.Errorf("[%s] Wrong value for ComplexExtension: got: %s want: %s\n", test.name, ext, want)
  450. }
  451. }
  452. }
  453. func TestClearAllExtensions(t *testing.T) {
  454. // unregistered extension
  455. desc := &proto.ExtensionDesc{
  456. ExtendedType: (*pb.MyMessage)(nil),
  457. ExtensionType: (*bool)(nil),
  458. Field: 101010100,
  459. Name: "emptyextension",
  460. Tag: "varint,0,opt",
  461. }
  462. m := &pb.MyMessage{}
  463. if proto.HasExtension(m, desc) {
  464. t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
  465. }
  466. if err := proto.SetExtension(m, desc, proto.Bool(true)); err != nil {
  467. t.Errorf("proto.SetExtension(m, desc, true): got error %q, want nil", err)
  468. }
  469. if !proto.HasExtension(m, desc) {
  470. t.Errorf("proto.HasExtension(%s): got false, want true", proto.MarshalTextString(m))
  471. }
  472. proto.ClearAllExtensions(m)
  473. if proto.HasExtension(m, desc) {
  474. t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
  475. }
  476. }
  477. func TestMarshalRace(t *testing.T) {
  478. // unregistered extension
  479. desc := &proto.ExtensionDesc{
  480. ExtendedType: (*pb.MyMessage)(nil),
  481. ExtensionType: (*bool)(nil),
  482. Field: 101010100,
  483. Name: "emptyextension",
  484. Tag: "varint,0,opt",
  485. }
  486. m := &pb.MyMessage{Count: proto.Int32(4)}
  487. if err := proto.SetExtension(m, desc, proto.Bool(true)); err != nil {
  488. t.Errorf("proto.SetExtension(m, desc, true): got error %q, want nil", err)
  489. }
  490. var g errgroup.Group
  491. for n := 3; n > 0; n-- {
  492. g.Go(func() error {
  493. _, err := proto.Marshal(m)
  494. return err
  495. })
  496. }
  497. if err := g.Wait(); err != nil {
  498. t.Fatal(err)
  499. }
  500. }