main.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. // Copyright 2011 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package main
  5. import (
  6. "bytes"
  7. "flag"
  8. "fmt"
  9. "go/ast"
  10. "go/format"
  11. "go/parser"
  12. "go/scanner"
  13. "go/token"
  14. "io/ioutil"
  15. "os"
  16. "os/exec"
  17. "path/filepath"
  18. "sort"
  19. "strings"
  20. )
  var (
	// fset is the single position table shared by every parse and print
	// performed by this command.
	fset = token.NewFileSet()

	// exitCode is the status main exits with; report raises it to 2.
	exitCode int
)

// Command-line flags. See usage for the synopsis.
var (
	allowedRewrites = flag.String("r", "",
		"restrict the rewrites to this comma-separated list")
	forceRewrites = flag.String("force", "",
		"force these fixes to run even if the code looks updated")
	doDiff = flag.Bool("diff", false, "display diffs instead of rewriting files")
)

// allowed and force hold the fix names given via -r and -force.
// Each stays nil when its flag was not supplied.
var allowed, force map[string]bool

// debug, when flipped to true while investigating a fix failure, makes
// processFile print source it could not re-parse and exit immediately.
const debug = false
  33. func usage() {
  34. fmt.Fprintf(os.Stderr, "usage: aefix [-diff] [-r fixname,...] [-force fixname,...] [path ...]\n")
  35. flag.PrintDefaults()
  36. fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n")
  37. sort.Sort(byName(fixes))
  38. for _, f := range fixes {
  39. fmt.Fprintf(os.Stderr, "\n%s\n", f.name)
  40. desc := strings.TrimSpace(f.desc)
  41. desc = strings.Replace(desc, "\n", "\n\t", -1)
  42. fmt.Fprintf(os.Stderr, "\t%s\n", desc)
  43. }
  44. os.Exit(2)
  45. }
  46. func main() {
  47. flag.Usage = usage
  48. flag.Parse()
  49. sort.Sort(byDate(fixes))
  50. if *allowedRewrites != "" {
  51. allowed = make(map[string]bool)
  52. for _, f := range strings.Split(*allowedRewrites, ",") {
  53. allowed[f] = true
  54. }
  55. }
  56. if *forceRewrites != "" {
  57. force = make(map[string]bool)
  58. for _, f := range strings.Split(*forceRewrites, ",") {
  59. force[f] = true
  60. }
  61. }
  62. if flag.NArg() == 0 {
  63. if err := processFile("standard input", true); err != nil {
  64. report(err)
  65. }
  66. os.Exit(exitCode)
  67. }
  68. for i := 0; i < flag.NArg(); i++ {
  69. path := flag.Arg(i)
  70. switch dir, err := os.Stat(path); {
  71. case err != nil:
  72. report(err)
  73. case dir.IsDir():
  74. walkDir(path)
  75. default:
  76. if err := processFile(path, false); err != nil {
  77. report(err)
  78. }
  79. }
  80. }
  81. os.Exit(exitCode)
  82. }
  // parserMode is used for every parse. Comments are kept in the AST so that
// files printed back out by gofmtFile retain them.
const parserMode parser.Mode = parser.ParseComments
  84. func gofmtFile(f *ast.File) ([]byte, error) {
  85. var buf bytes.Buffer
  86. if err := format.Node(&buf, fset, f); err != nil {
  87. return nil, err
  88. }
  89. return buf.Bytes(), nil
  90. }
  91. func processFile(filename string, useStdin bool) error {
  92. var f *os.File
  93. var err error
  94. var fixlog bytes.Buffer
  95. if useStdin {
  96. f = os.Stdin
  97. } else {
  98. f, err = os.Open(filename)
  99. if err != nil {
  100. return err
  101. }
  102. defer f.Close()
  103. }
  104. src, err := ioutil.ReadAll(f)
  105. if err != nil {
  106. return err
  107. }
  108. file, err := parser.ParseFile(fset, filename, src, parserMode)
  109. if err != nil {
  110. return err
  111. }
  112. // Apply all fixes to file.
  113. newFile := file
  114. fixed := false
  115. for _, fix := range fixes {
  116. if allowed != nil && !allowed[fix.name] {
  117. continue
  118. }
  119. if fix.f(newFile) {
  120. fixed = true
  121. fmt.Fprintf(&fixlog, " %s", fix.name)
  122. // AST changed.
  123. // Print and parse, to update any missing scoping
  124. // or position information for subsequent fixers.
  125. newSrc, err := gofmtFile(newFile)
  126. if err != nil {
  127. return err
  128. }
  129. newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode)
  130. if err != nil {
  131. if debug {
  132. fmt.Printf("%s", newSrc)
  133. report(err)
  134. os.Exit(exitCode)
  135. }
  136. return err
  137. }
  138. }
  139. }
  140. if !fixed {
  141. return nil
  142. }
  143. fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:])
  144. // Print AST. We did that after each fix, so this appears
  145. // redundant, but it is necessary to generate gofmt-compatible
  146. // source code in a few cases. The official gofmt style is the
  147. // output of the printer run on a standard AST generated by the parser,
  148. // but the source we generated inside the loop above is the
  149. // output of the printer run on a mangled AST generated by a fixer.
  150. newSrc, err := gofmtFile(newFile)
  151. if err != nil {
  152. return err
  153. }
  154. if *doDiff {
  155. data, err := diff(src, newSrc)
  156. if err != nil {
  157. return fmt.Errorf("computing diff: %s", err)
  158. }
  159. fmt.Printf("diff %s fixed/%s\n", filename, filename)
  160. os.Stdout.Write(data)
  161. return nil
  162. }
  163. if useStdin {
  164. os.Stdout.Write(newSrc)
  165. return nil
  166. }
  167. return ioutil.WriteFile(f.Name(), newSrc, 0)
  168. }
  169. var gofmtBuf bytes.Buffer
  170. func gofmt(n interface{}) string {
  171. gofmtBuf.Reset()
  172. if err := format.Node(&gofmtBuf, fset, n); err != nil {
  173. return "<" + err.Error() + ">"
  174. }
  175. return gofmtBuf.String()
  176. }
  177. func report(err error) {
  178. scanner.PrintError(os.Stderr, err)
  179. exitCode = 2
  180. }
  181. func walkDir(path string) {
  182. filepath.Walk(path, visitFile)
  183. }
  184. func visitFile(path string, f os.FileInfo, err error) error {
  185. if err == nil && isGoFile(f) {
  186. err = processFile(path, false)
  187. }
  188. if err != nil {
  189. report(err)
  190. }
  191. return nil
  192. }
  // isGoFile reports whether f should be processed as Go source: it must not
// be a directory, its name must not start with "." (hidden files are
// skipped), and its name must end in ".go".
func isGoFile(f os.FileInfo) bool {
	if f.IsDir() {
		return false
	}
	base := f.Name()
	hidden := strings.HasPrefix(base, ".")
	return strings.HasSuffix(base, ".go") && !hidden
}
  // diff returns the output of `diff -u` comparing b1 (old) with b2 (new).
// Each input is written to its own temporary file, which is closed before
// the external command runs and removed afterwards. Identical inputs yield
// empty output and a nil error.
//
// diff exits non-zero when the inputs differ, so a command error is ignored
// whenever the command produced output. An error is returned if a temporary
// file cannot be created or fully written, or if diff produced no output
// and failed (for example, when it is not installed).
func diff(b1, b2 []byte) (data []byte, err error) {
	name1, err := writeDiffTempFile(b1)
	if err != nil {
		return nil, err
	}
	defer os.Remove(name1)

	name2, err := writeDiffTempFile(b2)
	if err != nil {
		return nil, err
	}
	defer os.Remove(name2)

	data, err = exec.Command("diff", "-u", name1, name2).CombinedOutput()
	if len(data) > 0 {
		// diff exits with a non-zero status when the files don't match.
		// Ignore that failure as long as we get output.
		err = nil
	}
	return data, err
}

// writeDiffTempFile writes b to a new temporary file, closes it, and returns
// its name. On any write or close failure the file is removed and the error
// is returned, so diff never compares truncated input.
func writeDiffTempFile(b []byte) (string, error) {
	f, err := ioutil.TempFile("", "go-fix")
	if err != nil {
		return "", err
	}
	name := f.Name()
	_, werr := f.Write(b)
	if cerr := f.Close(); werr == nil {
		werr = cerr
	}
	if werr != nil {
		os.Remove(name)
		return "", werr
	}
	return name, nil
}