You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

template_repo.go 15KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511
  1. package generator
  import (
	"bytes"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"log"
	"os"
	"path"
	"path/filepath"
	"sort"
	"strings"
	"text/template"
	"text/template/parse"

	"github.com/go-openapi/inflect"
	"github.com/go-openapi/swag"
	"github.com/kr/pretty"
)
  18. var templates *Repository
  19. // FuncMap is a map with default functions for use n the templates.
  20. // These are available in every template
  21. var FuncMap template.FuncMap = map[string]interface{}{
  22. "pascalize": pascalize,
  23. "camelize": swag.ToJSONName,
  24. "varname": golang.MangleVarName,
  25. "humanize": swag.ToHumanNameLower,
  26. "snakize": golang.MangleFileName,
  27. "toPackagePath": func(name string) string {
  28. return filepath.FromSlash(golang.ManglePackagePath(name, ""))
  29. },
  30. "toPackage": func(name string) string {
  31. return golang.ManglePackagePath(name, "")
  32. },
  33. "toPackageName": func(name string) string {
  34. return golang.ManglePackageName(name, "")
  35. },
  36. "dasherize": swag.ToCommandName,
  37. "pluralizeFirstWord": func(arg string) string {
  38. sentence := strings.Split(arg, " ")
  39. if len(sentence) == 1 {
  40. return inflect.Pluralize(arg)
  41. }
  42. return inflect.Pluralize(sentence[0]) + " " + strings.Join(sentence[1:], " ")
  43. },
  44. "json": asJSON,
  45. "prettyjson": asPrettyJSON,
  46. "hasInsecure": func(arg []string) bool {
  47. return swag.ContainsStringsCI(arg, "http") || swag.ContainsStringsCI(arg, "ws")
  48. },
  49. "hasSecure": func(arg []string) bool {
  50. return swag.ContainsStringsCI(arg, "https") || swag.ContainsStringsCI(arg, "wss")
  51. },
  52. // TODO: simplify redundant functions
  53. "stripPackage": func(str, pkg string) string {
  54. parts := strings.Split(str, ".")
  55. strlen := len(parts)
  56. if strlen > 0 {
  57. return parts[strlen-1]
  58. }
  59. return str
  60. },
  61. "dropPackage": func(str string) string {
  62. parts := strings.Split(str, ".")
  63. strlen := len(parts)
  64. if strlen > 0 {
  65. return parts[strlen-1]
  66. }
  67. return str
  68. },
  69. "upper": strings.ToUpper,
  70. "contains": func(coll []string, arg string) bool {
  71. for _, v := range coll {
  72. if v == arg {
  73. return true
  74. }
  75. }
  76. return false
  77. },
  78. "padSurround": func(entry, padWith string, i, ln int) string {
  79. var res []string
  80. if i > 0 {
  81. for j := 0; j < i; j++ {
  82. res = append(res, padWith)
  83. }
  84. }
  85. res = append(res, entry)
  86. tot := ln - i - 1
  87. for j := 0; j < tot; j++ {
  88. res = append(res, padWith)
  89. }
  90. return strings.Join(res, ",")
  91. },
  92. "joinFilePath": filepath.Join,
  93. "comment": func(str string) string {
  94. lines := strings.Split(str, "\n")
  95. return (strings.Join(lines, "\n// "))
  96. },
  97. "blockcomment": func(str string) string {
  98. return strings.Replace(str, "*/", "[*]/", -1)
  99. },
  100. "inspect": pretty.Sprint,
  101. "cleanPath": path.Clean,
  102. "mediaTypeName": func(orig string) string {
  103. return strings.SplitN(orig, ";", 2)[0]
  104. },
  105. "goSliceInitializer": goSliceInitializer,
  106. "hasPrefix": strings.HasPrefix,
  107. "stringContains": strings.Contains,
  108. }
  109. func init() {
  110. templates = NewRepository(FuncMap)
  111. }
  112. var assets = map[string][]byte{
  113. "validation/primitive.gotmpl": MustAsset("templates/validation/primitive.gotmpl"),
  114. "validation/customformat.gotmpl": MustAsset("templates/validation/customformat.gotmpl"),
  115. "docstring.gotmpl": MustAsset("templates/docstring.gotmpl"),
  116. "validation/structfield.gotmpl": MustAsset("templates/validation/structfield.gotmpl"),
  117. "modelvalidator.gotmpl": MustAsset("templates/modelvalidator.gotmpl"),
  118. "structfield.gotmpl": MustAsset("templates/structfield.gotmpl"),
  119. "tupleserializer.gotmpl": MustAsset("templates/tupleserializer.gotmpl"),
  120. "additionalpropertiesserializer.gotmpl": MustAsset("templates/additionalpropertiesserializer.gotmpl"),
  121. "schematype.gotmpl": MustAsset("templates/schematype.gotmpl"),
  122. "schemabody.gotmpl": MustAsset("templates/schemabody.gotmpl"),
  123. "schema.gotmpl": MustAsset("templates/schema.gotmpl"),
  124. "schemavalidator.gotmpl": MustAsset("templates/schemavalidator.gotmpl"),
  125. "model.gotmpl": MustAsset("templates/model.gotmpl"),
  126. "header.gotmpl": MustAsset("templates/header.gotmpl"),
  127. "swagger_json_embed.gotmpl": MustAsset("templates/swagger_json_embed.gotmpl"),
  128. "server/parameter.gotmpl": MustAsset("templates/server/parameter.gotmpl"),
  129. "server/urlbuilder.gotmpl": MustAsset("templates/server/urlbuilder.gotmpl"),
  130. "server/responses.gotmpl": MustAsset("templates/server/responses.gotmpl"),
  131. "server/operation.gotmpl": MustAsset("templates/server/operation.gotmpl"),
  132. "server/builder.gotmpl": MustAsset("templates/server/builder.gotmpl"),
  133. "server/server.gotmpl": MustAsset("templates/server/server.gotmpl"),
  134. "server/configureapi.gotmpl": MustAsset("templates/server/configureapi.gotmpl"),
  135. "server/main.gotmpl": MustAsset("templates/server/main.gotmpl"),
  136. "server/doc.gotmpl": MustAsset("templates/server/doc.gotmpl"),
  137. "client/parameter.gotmpl": MustAsset("templates/client/parameter.gotmpl"),
  138. "client/response.gotmpl": MustAsset("templates/client/response.gotmpl"),
  139. "client/client.gotmpl": MustAsset("templates/client/client.gotmpl"),
  140. "client/facade.gotmpl": MustAsset("templates/client/facade.gotmpl"),
  141. }
  // protectedTemplates is the set of template names that AddFile refuses to
// (re)define; only the embedded defaults and contrib templates may provide
// them. Several names appear in more than one casing (e.g. "schemabody" and
// "schemaBody"), and each casing is protected.
var protectedTemplates = func() map[string]bool {
	names := []string{
		// schema / model rendering
		"schema", "schemabody", "schemaBody", "schematype", "schemaType",
		"typeSchemaType", "dereffedSchemaType", "model", "header", "docstring",
		"withBaseTypeBody", "withoutBaseTypeBody", "subTypeBody",
		// struct and tuple fields
		"structfield", "privstructfield", "structfieldIface",
		"tuplefield", "privtuplefield", "tuplefieldIface",
		// serializers
		"tupleserializer", "tupleSerializer", "schemaSerializer",
		"additionalpropertiesserializer", "additionalPropertiesSerializer",
		"hasDiscriminatedSerializer", "discriminatedSerializer",
		// validators
		"modelvalidator", "schemavalidator", "primitivefieldvalidator",
		"mapvalidator", "propertyvalidator", "objectvalidator", "slicevalidator",
		"validationStructfield", "validationPrimitive", "validationCustomformat",
		"propertyValidationDocString",
		// misc
		"swaggerJsonEmbed", "serverDoc",
	}
	set := make(map[string]bool, len(names))
	for _, n := range names {
		set[n] = true
	}
	return set
}()
  183. // AddFile adds a file to the default repository. It will create a new template based on the filename.
  184. // It trims the .gotmpl from the end and converts the name using swag.ToJSONName. This will strip
  185. // directory separators and Camelcase the next letter.
  186. // e.g validation/primitive.gotmpl will become validationPrimitive
  187. //
  188. // If the file contains a definition for a template that is protected the whole file will not be added
  189. func AddFile(name, data string) error {
  190. return templates.addFile(name, data, false)
  191. }
  // asJSON renders data as compact JSON for use in templates. On a marshaling
// failure it returns "" and the encoder's error.
func asJSON(data interface{}) (string, error) {
	raw, err := json.Marshal(data)
	if err == nil {
		return string(raw), nil
	}
	return "", err
}
  // asPrettyJSON renders data as JSON indented with a single space per level.
// On a marshaling failure it returns "" and the encoder's error.
func asPrettyJSON(data interface{}) (string, error) {
	const indent = " "
	raw, err := json.MarshalIndent(data, "", indent)
	if err == nil {
		return string(raw), nil
	}
	return "", err
}
  // goSliceInitializer constructs a Go literal initializer from interface{} literals,
// working from their JSON encoding.
// e.g. []interface{}{"a", "b"} is transformed in {"a","b",}
// e.g. map[string]interface{}{ "a": "x", "b": "y"} is transformed in {"a":"x","b":"y",}.
//
// Only the JSON structure is rewritten. "[" becomes "{", and every "]" or "}"
// gets a trailing comma before the closing brace. Bytes inside JSON string
// values are copied verbatim, so values such as "a]b" survive intact. Empty
// collections become "{}" rather than the invalid "{,}".
//
// NOTE: this is currently used to construct simple slice initializers for default values.
// This allows for nicer slice initializers for slices of primitive types and avoid systematic use for json.Unmarshal().
func goSliceInitializer(data interface{}) (string, error) {
	b, err := json.Marshal(data)
	if err != nil {
		return "", err
	}
	var out bytes.Buffer
	out.Grow(len(b) + len(b)/2)
	inString, escaped := false, false
	for i, c := range b {
		if inString {
			out.WriteByte(c)
			switch {
			case escaped:
				escaped = false
			case c == '\\':
				escaped = true
			case c == '"':
				inString = false
			}
			continue
		}
		switch c {
		case '"':
			inString = true
			out.WriteByte(c)
		case '[':
			out.WriteByte('{')
		case ']', '}':
			// json.Marshal emits no whitespace, so an opener immediately before
			// this closer means the collection is empty: no trailing comma.
			if i == 0 || (b[i-1] != '[' && b[i-1] != '{') {
				out.WriteByte(',')
			}
			out.WriteByte('}')
		default:
			out.WriteByte(c)
		}
	}
	return out.String(), nil
}
  // NewRepository returns an empty template repository whose templates will be
// parsed with funcs available. A nil funcs is replaced by an empty FuncMap;
// a non-nil map is stored as given (not copied).
func NewRepository(funcs template.FuncMap) *Repository {
	if funcs == nil {
		funcs = make(template.FuncMap)
	}
	return &Repository{
		files:     make(map[string]string),
		templates: make(map[string]*template.Template),
		funcs:     funcs,
	}
}

// Repository is the repository for the generator templates.
type Repository struct {
	// files records, per template name, the file name it was added from.
	files map[string]string
	// templates holds every registered template by name.
	templates map[string]*template.Template
	// funcs is installed on each template set before parsing.
	funcs template.FuncMap
}
  237. // LoadDefaults will load the embedded templates
  238. func (t *Repository) LoadDefaults() {
  239. for name, asset := range assets {
  240. if err := t.addFile(name, string(asset), true); err != nil {
  241. log.Fatal(err)
  242. }
  243. }
  244. }
  245. // LoadDir will walk the specified path and add each .gotmpl file it finds to the repository
  246. func (t *Repository) LoadDir(templatePath string) error {
  247. err := filepath.Walk(templatePath, func(path string, info os.FileInfo, err error) error {
  248. if strings.HasSuffix(path, ".gotmpl") {
  249. if assetName, e := filepath.Rel(templatePath, path); e == nil {
  250. if data, e := ioutil.ReadFile(path); e == nil {
  251. if ee := t.AddFile(assetName, string(data)); ee != nil {
  252. // Fatality is decided by caller
  253. // log.Fatal(ee)
  254. return fmt.Errorf("could not add template: %v", ee)
  255. }
  256. }
  257. // Non-readable files are skipped
  258. }
  259. }
  260. if err != nil {
  261. return err
  262. }
  263. // Non-template files are skipped
  264. return nil
  265. })
  266. if err != nil {
  267. return fmt.Errorf("could not complete template processing in directory \"%s\": %v", templatePath, err)
  268. }
  269. return nil
  270. }
  271. // LoadContrib loads template from contrib directory
  272. func (t *Repository) LoadContrib(name string) error {
  273. log.Printf("loading contrib %s", name)
  274. const pathPrefix = "templates/contrib/"
  275. basePath := pathPrefix + name
  276. filesAdded := 0
  277. for _, aname := range AssetNames() {
  278. if !strings.HasSuffix(aname, ".gotmpl") {
  279. continue
  280. }
  281. if strings.HasPrefix(aname, basePath) {
  282. target := aname[len(basePath)+1:]
  283. err := t.addFile(target, string(MustAsset(aname)), true)
  284. if err != nil {
  285. return err
  286. }
  287. log.Printf("added contributed template %s from %s", target, aname)
  288. filesAdded++
  289. }
  290. }
  291. if filesAdded == 0 {
  292. return fmt.Errorf("no files added from template: %s", name)
  293. }
  294. return nil
  295. }
  296. func (t *Repository) addFile(name, data string, allowOverride bool) error {
  297. fileName := name
  298. name = swag.ToJSONName(strings.TrimSuffix(name, ".gotmpl"))
  299. templ, err := template.New(name).Funcs(t.funcs).Parse(data)
  300. if err != nil {
  301. return fmt.Errorf("failed to load template %s: %v", name, err)
  302. }
  303. // check if any protected templates are defined
  304. if !allowOverride {
  305. for _, template := range templ.Templates() {
  306. if protectedTemplates[template.Name()] {
  307. return fmt.Errorf("cannot overwrite protected template %s", template.Name())
  308. }
  309. }
  310. }
  311. // Add each defined template into the cache
  312. for _, template := range templ.Templates() {
  313. t.files[template.Name()] = fileName
  314. t.templates[template.Name()] = template.Lookup(template.Name())
  315. }
  316. return nil
  317. }
  318. // MustGet a template by name, panics when fails
  319. func (t *Repository) MustGet(name string) *template.Template {
  320. tpl, err := t.Get(name)
  321. if err != nil {
  322. panic(err)
  323. }
  324. return tpl
  325. }
  326. // AddFile adds a file to the repository. It will create a new template based on the filename.
  327. // It trims the .gotmpl from the end and converts the name using swag.ToJSONName. This will strip
  328. // directory separators and Camelcase the next letter.
  329. // e.g validation/primitive.gotmpl will become validationPrimitive
  330. //
  331. // If the file contains a definition for a template that is protected the whole file will not be added
  332. func (t *Repository) AddFile(name, data string) error {
  333. return t.addFile(name, data, false)
  334. }
  // findDependencies returns the names of templates invoked with
// {{template "name"}} inside n. It descends into list nodes and into both the
// body and the else-branch of if, range and with actions; other node kinds
// contribute nothing. Each name appears once, in unspecified (map iteration)
// order. A nil node, or one without template calls, yields a nil slice.
func findDependencies(n parse.Node) []string {
	if n == nil {
		return nil
	}
	found := make(map[string]bool)
	walk := func(child parse.Node) {
		for _, name := range findDependencies(child) {
			found[name] = true
		}
	}
	switch node := n.(type) {
	case *parse.TemplateNode:
		found[node.Name] = true
	case *parse.ListNode:
		// Absent branch lists arrive here as typed nil pointers.
		if node != nil {
			for _, child := range node.Nodes {
				walk(child)
			}
		}
	case *parse.IfNode:
		walk(node.List)
		walk(node.ElseList)
	case *parse.RangeNode:
		walk(node.List)
		walk(node.ElseList)
	case *parse.WithNode:
		walk(node.List)
		walk(node.ElseList)
	}
	var deps []string
	for name := range found {
		deps = append(deps, name)
	}
	return deps
}
  379. func (t *Repository) flattenDependencies(templ *template.Template, dependencies map[string]bool) map[string]bool {
  380. if dependencies == nil {
  381. dependencies = make(map[string]bool)
  382. }
  383. deps := findDependencies(templ.Tree.Root)
  384. for _, d := range deps {
  385. if _, found := dependencies[d]; !found {
  386. dependencies[d] = true
  387. if tt := t.templates[d]; tt != nil {
  388. dependencies = t.flattenDependencies(tt, dependencies)
  389. }
  390. }
  391. dependencies[d] = true
  392. }
  393. return dependencies
  394. }
  395. func (t *Repository) addDependencies(templ *template.Template) (*template.Template, error) {
  396. name := templ.Name()
  397. deps := t.flattenDependencies(templ, nil)
  398. for dep := range deps {
  399. if dep == "" {
  400. continue
  401. }
  402. tt := templ.Lookup(dep)
  403. // Check if we have it
  404. if tt == nil {
  405. tt = t.templates[dep]
  406. // Still don't have it, return an error
  407. if tt == nil {
  408. return templ, fmt.Errorf("could not find template %s", dep)
  409. }
  410. var err error
  411. // Add it to the parse tree
  412. templ, err = templ.AddParseTree(dep, tt.Tree)
  413. if err != nil {
  414. return templ, fmt.Errorf("dependency error: %v", err)
  415. }
  416. }
  417. }
  418. return templ.Lookup(name), nil
  419. }
  420. // Get will return the named template from the repository, ensuring that all dependent templates are loaded.
  421. // It will return an error if a dependent template is not defined in the repository.
  422. func (t *Repository) Get(name string) (*template.Template, error) {
  423. templ, found := t.templates[name]
  424. if !found {
  425. return templ, fmt.Errorf("template doesn't exist %s", name)
  426. }
  427. return t.addDependencies(templ)
  428. }
  429. // DumpTemplates prints out a dump of all the defined templates, where they are defined and what their dependencies are.
  430. func (t *Repository) DumpTemplates() {
  431. buf := bytes.NewBuffer(nil)
  432. fmt.Fprintln(buf, "\n# Templates")
  433. for name, templ := range t.templates {
  434. fmt.Fprintf(buf, "## %s\n", name)
  435. fmt.Fprintf(buf, "Defined in `%s`\n", t.files[name])
  436. if deps := findDependencies(templ.Tree.Root); len(deps) > 0 {
  437. fmt.Fprintf(buf, "####requires \n - %v\n\n\n", strings.Join(deps, "\n - "))
  438. }
  439. fmt.Fprintln(buf, "\n---")
  440. }
  441. log.Println(buf.String())
  442. }