You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

formatimports.go 4.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. // Copyright 2021 The Gitea Authors. All rights reserved.
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE file.
  4. package codeformat
  5. import (
  6. "bytes"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "os"
  11. "sort"
  12. "strings"
  13. )
  // Configuration and fixed values shared by the import formatter.
var (
	// importPackageGroupOrders gives the output position of known import groups
	// (a lower value is written first). Groups without an entry are written after
	// these, ordered by name. The "" group holds packages that were not
	// classified into any other group.
	importPackageGroupOrders = map[string]int{
		"":                     1, // internal
		"code.gitea.io/gitea/": 2,
	}

	// errInvalidCommentBetweenImports is returned when a line inside the import
	// block starts with a comment; only comments at the end of an import line
	// are accepted.
	errInvalidCommentBetweenImports = errors.New("comments between imported packages are invalid, please move comments to the end of the package line")

	// importBlockBegin and importBlockEnd delimit a parenthesized import block.
	importBlockBegin = []byte("\nimport (\n")
	importBlockEnd   = []byte("\n)")
)
  // importLineParsed is one non-empty line of an import block after parsing.
type importLineParsed struct {
	// group is the key the import is sorted under: "", "domain-package",
	// or a key of importPackageGroupOrders.
	group string
	// pkg is the import path: the text between the first two double quotes.
	pkg string
	// content is the whole whitespace-trimmed line as written, including any
	// alias and trailing comment; it is emitted verbatim.
	content string
}
  28. func parseImportLine(line string) (*importLineParsed, error) {
  29. il := &importLineParsed{content: line}
  30. p1 := strings.IndexRune(line, '"')
  31. if p1 == -1 {
  32. return nil, errors.New("invalid import line: " + line)
  33. }
  34. p1++
  35. p := strings.IndexRune(line[p1:], '"')
  36. if p == -1 {
  37. return nil, errors.New("invalid import line: " + line)
  38. }
  39. p2 := p1 + p
  40. il.pkg = line[p1:p2]
  41. pDot := strings.IndexRune(il.pkg, '.')
  42. pSlash := strings.IndexRune(il.pkg, '/')
  43. if pDot != -1 && pDot < pSlash {
  44. il.group = "domain-package"
  45. }
  46. for groupName := range importPackageGroupOrders {
  47. if groupName == "" {
  48. continue // skip internal
  49. }
  50. if strings.HasPrefix(il.pkg, groupName) {
  51. il.group = groupName
  52. }
  53. }
  54. return il, nil
  55. }
  56. type (
  57. importLineGroup []*importLineParsed
  58. importLineGroupMap map[string]importLineGroup
  59. )
  60. func formatGoImports(contentBytes []byte) ([]byte, error) {
  61. p1 := bytes.Index(contentBytes, importBlockBegin)
  62. if p1 == -1 {
  63. return nil, nil
  64. }
  65. p1 += len(importBlockBegin)
  66. p := bytes.Index(contentBytes[p1:], importBlockEnd)
  67. if p == -1 {
  68. return nil, nil
  69. }
  70. p2 := p1 + p
  71. importGroups := importLineGroupMap{}
  72. r := bytes.NewBuffer(contentBytes[p1:p2])
  73. eof := false
  74. for !eof {
  75. line, err := r.ReadString('\n')
  76. eof = err == io.EOF
  77. if err != nil && !eof {
  78. return nil, err
  79. }
  80. line = strings.TrimSpace(line)
  81. if line != "" {
  82. if strings.HasPrefix(line, "//") || strings.HasPrefix(line, "/*") {
  83. return nil, errInvalidCommentBetweenImports
  84. }
  85. importLine, err := parseImportLine(line)
  86. if err != nil {
  87. return nil, err
  88. }
  89. importGroups[importLine.group] = append(importGroups[importLine.group], importLine)
  90. }
  91. }
  92. var groupNames []string
  93. for groupName, importLines := range importGroups {
  94. groupNames = append(groupNames, groupName)
  95. sort.Slice(importLines, func(i, j int) bool {
  96. return strings.Compare(importLines[i].pkg, importLines[j].pkg) < 0
  97. })
  98. }
  99. sort.Slice(groupNames, func(i, j int) bool {
  100. n1 := groupNames[i]
  101. n2 := groupNames[j]
  102. o1 := importPackageGroupOrders[n1]
  103. o2 := importPackageGroupOrders[n2]
  104. if o1 != 0 && o2 != 0 {
  105. return o1 < o2
  106. }
  107. if o1 == 0 && o2 == 0 {
  108. return strings.Compare(n1, n2) < 0
  109. }
  110. return o1 != 0
  111. })
  112. formattedBlock := bytes.Buffer{}
  113. for _, groupName := range groupNames {
  114. hasNormalImports := false
  115. hasDummyImports := false
  116. // non-dummy import comes first
  117. for _, importLine := range importGroups[groupName] {
  118. if strings.HasPrefix(importLine.content, "_") {
  119. hasDummyImports = true
  120. } else {
  121. formattedBlock.WriteString("\t" + importLine.content + "\n")
  122. hasNormalImports = true
  123. }
  124. }
  125. // dummy (_ "pkg") comes later
  126. if hasDummyImports {
  127. if hasNormalImports {
  128. formattedBlock.WriteString("\n")
  129. }
  130. for _, importLine := range importGroups[groupName] {
  131. if strings.HasPrefix(importLine.content, "_") {
  132. formattedBlock.WriteString("\t" + importLine.content + "\n")
  133. }
  134. }
  135. }
  136. formattedBlock.WriteString("\n")
  137. }
  138. formattedBlockBytes := bytes.TrimRight(formattedBlock.Bytes(), "\n")
  139. var formattedBytes []byte
  140. formattedBytes = append(formattedBytes, contentBytes[:p1]...)
  141. formattedBytes = append(formattedBytes, formattedBlockBytes...)
  142. formattedBytes = append(formattedBytes, contentBytes[p2:]...)
  143. return formattedBytes, nil
  144. }
  145. // FormatGoImports format the imports by our rules (see unit tests)
  146. func FormatGoImports(file string, doChangedFiles, doWriteFile bool) error {
  147. f, err := os.Open(file)
  148. if err != nil {
  149. return err
  150. }
  151. var contentBytes []byte
  152. {
  153. defer f.Close()
  154. contentBytes, err = io.ReadAll(f)
  155. if err != nil {
  156. return err
  157. }
  158. }
  159. formattedBytes, err := formatGoImports(contentBytes)
  160. if err != nil {
  161. return err
  162. }
  163. if formattedBytes == nil {
  164. return nil
  165. }
  166. if bytes.Equal(contentBytes, formattedBytes) {
  167. return nil
  168. }
  169. if doChangedFiles {
  170. fmt.Println(file)
  171. }
  172. if doWriteFile {
  173. f, err = os.OpenFile(file, os.O_TRUNC|os.O_WRONLY, 0o644)
  174. if err != nil {
  175. return err
  176. }
  177. defer f.Close()
  178. _, err = f.Write(formattedBytes)
  179. return err
  180. }
  181. return err
  182. }