main_test.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
  4. package main
  5. import (
  6. "go/ast"
  7. "go/parser"
  8. "strings"
  9. "testing"
  10. )
  // testCase is one rewrite scenario exercised by TestRewrite.
//
// In is the source handed to the fixer and Out is the source expected after
// fixing and reformatting. Fn is the fix to apply; when it is nil,
// parseFixPrint runs every entry of fixes instead.
type testCase struct {
	Name    string               // label used in failure messages
	Fn      func(*ast.File) bool // fix under test; nil means "run every fix"
	In, Out string               // source before fixing / expected source after
}

// testCases accumulates the scenarios registered through addTestCases.
var testCases []testCase

// addTestCases appends the cases in t to testCases. When fn is non-nil it is
// stored as the Fn of every case in t that does not already name its own fix,
// so table definitions need not repeat it. The defaulting writes into the
// caller's slice t before the elements are copied into testCases.
func addTestCases(t []testCase, fn func(*ast.File) bool) {
	for i := 0; fn != nil && i < len(t); i++ {
		if t[i].Fn != nil {
			continue // the case already names a specific fix
		}
		t[i].Fn = fn
	}
	testCases = append(testCases, t...)
}
  // fnop is a fix that never touches the file and always reports "no change".
// TestRewrite uses it to reparse and reprint output without fixing anything.
func fnop(_ *ast.File) bool {
	return false
}
  30. func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string, mustBeGofmt bool) (out string, fixed, ok bool) {
  31. file, err := parser.ParseFile(fset, desc, in, parserMode)
  32. if err != nil {
  33. t.Errorf("%s: parsing: %v", desc, err)
  34. return
  35. }
  36. outb, err := gofmtFile(file)
  37. if err != nil {
  38. t.Errorf("%s: printing: %v", desc, err)
  39. return
  40. }
  41. if s := string(outb); in != s && mustBeGofmt {
  42. t.Errorf("%s: not gofmt-formatted.\n--- %s\n%s\n--- %s | gofmt\n%s",
  43. desc, desc, in, desc, s)
  44. tdiff(t, in, s)
  45. return
  46. }
  47. if fn == nil {
  48. for _, fix := range fixes {
  49. if fix.f(file) {
  50. fixed = true
  51. }
  52. }
  53. } else {
  54. fixed = fn(file)
  55. }
  56. outb, err = gofmtFile(file)
  57. if err != nil {
  58. t.Errorf("%s: printing: %v", desc, err)
  59. return
  60. }
  61. return string(outb), fixed, true
  62. }
  63. func TestRewrite(t *testing.T) {
  64. for _, tt := range testCases {
  65. // Apply fix: should get tt.Out.
  66. out, fixed, ok := parseFixPrint(t, tt.Fn, tt.Name, tt.In, true)
  67. if !ok {
  68. continue
  69. }
  70. // reformat to get printing right
  71. out, _, ok = parseFixPrint(t, fnop, tt.Name, out, false)
  72. if !ok {
  73. continue
  74. }
  75. if out != tt.Out {
  76. t.Errorf("%s: incorrect output.\n", tt.Name)
  77. if !strings.HasPrefix(tt.Name, "testdata/") {
  78. t.Errorf("--- have\n%s\n--- want\n%s", out, tt.Out)
  79. }
  80. tdiff(t, out, tt.Out)
  81. continue
  82. }
  83. if changed := out != tt.In; changed != fixed {
  84. t.Errorf("%s: changed=%v != fixed=%v", tt.Name, changed, fixed)
  85. continue
  86. }
  87. // Should not change if run again.
  88. out2, fixed2, ok := parseFixPrint(t, tt.Fn, tt.Name+" output", out, true)
  89. if !ok {
  90. continue
  91. }
  92. if fixed2 {
  93. t.Errorf("%s: applied fixes during second round", tt.Name)
  94. continue
  95. }
  96. if out2 != out {
  97. t.Errorf("%s: changed output after second round of fixes.\n--- output after first round\n%s\n--- output after second round\n%s",
  98. tt.Name, out, out2)
  99. tdiff(t, out, out2)
  100. }
  101. }
  102. }
  103. func tdiff(t *testing.T, a, b string) {
  104. data, err := diff([]byte(a), []byte(b))
  105. if err != nil {
  106. t.Error(err)
  107. return
  108. }
  109. t.Error(string(data))
  110. }