You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

formatimports.go 4.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. // Copyright 2021 The Gitea Authors. All rights reserved.
  2. // SPDX-License-Identifier: MIT
  3. package codeformat
  4. import (
  5. "bytes"
  6. "errors"
  7. "io"
  8. "os"
  9. "sort"
  10. "strings"
  11. )
  12. var importPackageGroupOrders = map[string]int{
  13. "": 1, // internal
  14. "code.gitea.io/gitea/": 2,
  15. }
  16. var errInvalidCommentBetweenImports = errors.New("comments between imported packages are invalid, please move comments to the end of the package line")
  17. var (
  18. importBlockBegin = []byte("\nimport (\n")
  19. importBlockEnd = []byte("\n)")
  20. )
// importLineParsed is a single line of an import block, as produced by
// parseImportLine.
type importLineParsed struct {
	group   string // sort group: "", "domain-package", or a key of importPackageGroupOrders
	pkg     string // import path, taken from between the first pair of double quotes
	content string // the line as passed to parseImportLine (alias, path, trailing comment); written back verbatim
}
  26. func parseImportLine(line string) (*importLineParsed, error) {
  27. il := &importLineParsed{content: line}
  28. p1 := strings.IndexRune(line, '"')
  29. if p1 == -1 {
  30. return nil, errors.New("invalid import line: " + line)
  31. }
  32. p1++
  33. p := strings.IndexRune(line[p1:], '"')
  34. if p == -1 {
  35. return nil, errors.New("invalid import line: " + line)
  36. }
  37. p2 := p1 + p
  38. il.pkg = line[p1:p2]
  39. pDot := strings.IndexRune(il.pkg, '.')
  40. pSlash := strings.IndexRune(il.pkg, '/')
  41. if pDot != -1 && pDot < pSlash {
  42. il.group = "domain-package"
  43. }
  44. for groupName := range importPackageGroupOrders {
  45. if groupName == "" {
  46. continue // skip internal
  47. }
  48. if strings.HasPrefix(il.pkg, groupName) {
  49. il.group = groupName
  50. }
  51. }
  52. return il, nil
  53. }
  54. type (
  55. importLineGroup []*importLineParsed
  56. importLineGroupMap map[string]importLineGroup
  57. )
  58. func formatGoImports(contentBytes []byte) ([]byte, error) {
  59. p1 := bytes.Index(contentBytes, importBlockBegin)
  60. if p1 == -1 {
  61. return nil, nil
  62. }
  63. p1 += len(importBlockBegin)
  64. p := bytes.Index(contentBytes[p1:], importBlockEnd)
  65. if p == -1 {
  66. return nil, nil
  67. }
  68. p2 := p1 + p
  69. importGroups := importLineGroupMap{}
  70. r := bytes.NewBuffer(contentBytes[p1:p2])
  71. eof := false
  72. for !eof {
  73. line, err := r.ReadString('\n')
  74. eof = err == io.EOF
  75. if err != nil && !eof {
  76. return nil, err
  77. }
  78. line = strings.TrimSpace(line)
  79. if line != "" {
  80. if strings.HasPrefix(line, "//") || strings.HasPrefix(line, "/*") {
  81. return nil, errInvalidCommentBetweenImports
  82. }
  83. importLine, err := parseImportLine(line)
  84. if err != nil {
  85. return nil, err
  86. }
  87. importGroups[importLine.group] = append(importGroups[importLine.group], importLine)
  88. }
  89. }
  90. var groupNames []string
  91. for groupName, importLines := range importGroups {
  92. groupNames = append(groupNames, groupName)
  93. sort.Slice(importLines, func(i, j int) bool {
  94. return strings.Compare(importLines[i].pkg, importLines[j].pkg) < 0
  95. })
  96. }
  97. sort.Slice(groupNames, func(i, j int) bool {
  98. n1 := groupNames[i]
  99. n2 := groupNames[j]
  100. o1 := importPackageGroupOrders[n1]
  101. o2 := importPackageGroupOrders[n2]
  102. if o1 != 0 && o2 != 0 {
  103. return o1 < o2
  104. }
  105. if o1 == 0 && o2 == 0 {
  106. return strings.Compare(n1, n2) < 0
  107. }
  108. return o1 != 0
  109. })
  110. formattedBlock := bytes.Buffer{}
  111. for _, groupName := range groupNames {
  112. hasNormalImports := false
  113. hasDummyImports := false
  114. // non-dummy import comes first
  115. for _, importLine := range importGroups[groupName] {
  116. if strings.HasPrefix(importLine.content, "_") {
  117. hasDummyImports = true
  118. } else {
  119. formattedBlock.WriteString("\t" + importLine.content + "\n")
  120. hasNormalImports = true
  121. }
  122. }
  123. // dummy (_ "pkg") comes later
  124. if hasDummyImports {
  125. if hasNormalImports {
  126. formattedBlock.WriteString("\n")
  127. }
  128. for _, importLine := range importGroups[groupName] {
  129. if strings.HasPrefix(importLine.content, "_") {
  130. formattedBlock.WriteString("\t" + importLine.content + "\n")
  131. }
  132. }
  133. }
  134. formattedBlock.WriteString("\n")
  135. }
  136. formattedBlockBytes := bytes.TrimRight(formattedBlock.Bytes(), "\n")
  137. var formattedBytes []byte
  138. formattedBytes = append(formattedBytes, contentBytes[:p1]...)
  139. formattedBytes = append(formattedBytes, formattedBlockBytes...)
  140. formattedBytes = append(formattedBytes, contentBytes[p2:]...)
  141. return formattedBytes, nil
  142. }
  143. // FormatGoImports format the imports by our rules (see unit tests)
  144. func FormatGoImports(file string, doWriteFile bool) error {
  145. f, err := os.Open(file)
  146. if err != nil {
  147. return err
  148. }
  149. var contentBytes []byte
  150. {
  151. defer f.Close()
  152. contentBytes, err = io.ReadAll(f)
  153. if err != nil {
  154. return err
  155. }
  156. }
  157. formattedBytes, err := formatGoImports(contentBytes)
  158. if err != nil {
  159. return err
  160. }
  161. if formattedBytes == nil {
  162. return nil
  163. }
  164. if bytes.Equal(contentBytes, formattedBytes) {
  165. return nil
  166. }
  167. if doWriteFile {
  168. f, err = os.OpenFile(file, os.O_TRUNC|os.O_WRONLY, 0o644)
  169. if err != nil {
  170. return err
  171. }
  172. defer f.Close()
  173. _, err = f.Write(formattedBytes)
  174. return err
  175. }
  176. return err
  177. }