123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195 |
- // Copyright 2021 The Gitea Authors. All rights reserved.
- // SPDX-License-Identifier: MIT
-
- package codeformat
-
- import (
- "bytes"
- "errors"
- "io"
- "os"
- "sort"
- "strings"
- )
-
// Configuration and markers used by the import formatter.
var (
	// importPackageGroupOrders gives the output rank of each configured import
	// group (lower ranks are written first). The "" key is the default group
	// assigned by parseImportLine (the standard library in practice); any group
	// that is not listed here is written after all listed groups.
	importPackageGroupOrders = map[string]int{
		"":                     1, // internal
		"code.gitea.io/gitea/": 2,
	}

	// errInvalidCommentBetweenImports is reported when a line inside the import
	// block is a comment on its own; comments must trail an import line instead.
	errInvalidCommentBetweenImports = errors.New("comments between imported packages are invalid, please move comments to the end of the package line")

	// importBlockBegin and importBlockEnd are the literal byte sequences that
	// open and close a parenthesized import block starting at a line boundary.
	importBlockBegin = []byte("\nimport (\n")
	importBlockEnd   = []byte("\n)")
)
-
// importLineParsed holds a single line of an import block together with the
// data needed to group and order it.
type importLineParsed struct {
	// group is the name of the group the import is placed in (set by parseImportLine).
	group string
	// pkg is the import path without its surrounding double quotes.
	pkg string
	// content is the whole trimmed source line, written back verbatim.
	content string
}
-
- func parseImportLine(line string) (*importLineParsed, error) {
- il := &importLineParsed{content: line}
- p1 := strings.IndexRune(line, '"')
- if p1 == -1 {
- return nil, errors.New("invalid import line: " + line)
- }
- p1++
- p := strings.IndexRune(line[p1:], '"')
- if p == -1 {
- return nil, errors.New("invalid import line: " + line)
- }
- p2 := p1 + p
- il.pkg = line[p1:p2]
-
- pDot := strings.IndexRune(il.pkg, '.')
- pSlash := strings.IndexRune(il.pkg, '/')
- if pDot != -1 && pDot < pSlash {
- il.group = "domain-package"
- }
- for groupName := range importPackageGroupOrders {
- if groupName == "" {
- continue // skip internal
- }
- if strings.HasPrefix(il.pkg, groupName) {
- il.group = groupName
- }
- }
- return il, nil
- }
-
- type (
- importLineGroup []*importLineParsed
- importLineGroupMap map[string]importLineGroup
- )
-
- func formatGoImports(contentBytes []byte) ([]byte, error) {
- p1 := bytes.Index(contentBytes, importBlockBegin)
- if p1 == -1 {
- return nil, nil
- }
- p1 += len(importBlockBegin)
- p := bytes.Index(contentBytes[p1:], importBlockEnd)
- if p == -1 {
- return nil, nil
- }
- p2 := p1 + p
-
- importGroups := importLineGroupMap{}
- r := bytes.NewBuffer(contentBytes[p1:p2])
- eof := false
- for !eof {
- line, err := r.ReadString('\n')
- eof = err == io.EOF
- if err != nil && !eof {
- return nil, err
- }
- line = strings.TrimSpace(line)
- if line != "" {
- if strings.HasPrefix(line, "//") || strings.HasPrefix(line, "/*") {
- return nil, errInvalidCommentBetweenImports
- }
- importLine, err := parseImportLine(line)
- if err != nil {
- return nil, err
- }
- importGroups[importLine.group] = append(importGroups[importLine.group], importLine)
- }
- }
-
- var groupNames []string
- for groupName, importLines := range importGroups {
- groupNames = append(groupNames, groupName)
- sort.Slice(importLines, func(i, j int) bool {
- return strings.Compare(importLines[i].pkg, importLines[j].pkg) < 0
- })
- }
-
- sort.Slice(groupNames, func(i, j int) bool {
- n1 := groupNames[i]
- n2 := groupNames[j]
- o1 := importPackageGroupOrders[n1]
- o2 := importPackageGroupOrders[n2]
- if o1 != 0 && o2 != 0 {
- return o1 < o2
- }
- if o1 == 0 && o2 == 0 {
- return strings.Compare(n1, n2) < 0
- }
- return o1 != 0
- })
-
- formattedBlock := bytes.Buffer{}
- for _, groupName := range groupNames {
- hasNormalImports := false
- hasDummyImports := false
- // non-dummy import comes first
- for _, importLine := range importGroups[groupName] {
- if strings.HasPrefix(importLine.content, "_") {
- hasDummyImports = true
- } else {
- formattedBlock.WriteString("\t" + importLine.content + "\n")
- hasNormalImports = true
- }
- }
- // dummy (_ "pkg") comes later
- if hasDummyImports {
- if hasNormalImports {
- formattedBlock.WriteString("\n")
- }
- for _, importLine := range importGroups[groupName] {
- if strings.HasPrefix(importLine.content, "_") {
- formattedBlock.WriteString("\t" + importLine.content + "\n")
- }
- }
- }
- formattedBlock.WriteString("\n")
- }
- formattedBlockBytes := bytes.TrimRight(formattedBlock.Bytes(), "\n")
-
- var formattedBytes []byte
- formattedBytes = append(formattedBytes, contentBytes[:p1]...)
- formattedBytes = append(formattedBytes, formattedBlockBytes...)
- formattedBytes = append(formattedBytes, contentBytes[p2:]...)
- return formattedBytes, nil
- }
-
- // FormatGoImports format the imports by our rules (see unit tests)
- func FormatGoImports(file string, doWriteFile bool) error {
- f, err := os.Open(file)
- if err != nil {
- return err
- }
- var contentBytes []byte
- {
- defer f.Close()
- contentBytes, err = io.ReadAll(f)
- if err != nil {
- return err
- }
- }
- formattedBytes, err := formatGoImports(contentBytes)
- if err != nil {
- return err
- }
- if formattedBytes == nil {
- return nil
- }
- if bytes.Equal(contentBytes, formattedBytes) {
- return nil
- }
-
- if doWriteFile {
- f, err = os.OpenFile(file, os.O_TRUNC|os.O_WRONLY, 0o644)
- if err != nil {
- return err
- }
- defer f.Close()
- _, err = f.Write(formattedBytes)
- return err
- }
-
- return err
- }
|