]> source.dussan.org Git - gitea.git/commitdiff
Fix NuGet Package API for $filter with Id equality (#31188) (#31242)
authorThomas Desveaux <thomas.desveaux@dont-nod.com>
Tue, 4 Jun 2024 11:56:59 +0000 (13:56 +0200)
committerGitHub <noreply@github.com>
Tue, 4 Jun 2024 11:56:59 +0000 (14:56 +0300)
Backport #31188

Fixes issue when running `choco info pkgname` where `pkgname` is also a
substring of another package Id.

Relates to #31168

---

This might fix the issue linked, but I'd like to test it with more choco
commands before closing the issue in case I find other problems if
that's ok.
I'm pretty inexperienced with Go, so feel free to nitpick things.

Not sure I handled
[this](https://github.com/tdesveaux/gitea/blob/70f87e11b5caf1ee441ae71c7eba1831f45897d4/routers/api/packages/nuget/nuget.go#L135-L137)
in the best way, so looking for feedback on if I should fix the
underlying issue (`nil` might be a better default for `Value`?).

Co-authored-by: KN4CK3R <admin@oldschoolhack.me>
routers/api/packages/nuget/nuget.go
tests/integration/api_packages_nuget_test.go

index 26b0ae226e45b7a751d2b8d865039f71f06ca9ef..3633d0d00704c5be2712ff4a819686d9e8c8570f 100644 (file)
@@ -96,20 +96,34 @@ func FeedCapabilityResource(ctx *context.Context) {
        xmlResponse(ctx, http.StatusOK, Metadata)
 }
 
-var searchTermExtract = regexp.MustCompile(`'([^']+)'`)
+var (
+       searchTermExtract = regexp.MustCompile(`'([^']+)'`)
+       searchTermExact   = regexp.MustCompile(`\s+eq\s+'`)
+)
 
-func getSearchTerm(ctx *context.Context) string {
+func getSearchTerm(ctx *context.Context) packages_model.SearchValue {
        searchTerm := strings.Trim(ctx.FormTrim("searchTerm"), "'")
-       if searchTerm == "" {
-               // $filter contains a query like:
-               // (((Id ne null) and substringof('microsoft',tolower(Id)))
-               // We don't support these queries, just extract the search term.
-               match := searchTermExtract.FindStringSubmatch(ctx.FormTrim("$filter"))
-               if len(match) == 2 {
-                       searchTerm = strings.TrimSpace(match[1])
+       if searchTerm != "" {
+               return packages_model.SearchValue{
+                       Value:      searchTerm,
+                       ExactMatch: false,
+               }
+       }
+
+       // $filter contains a query like:
+       // (((Id ne null) and substringof('microsoft',tolower(Id)))
+       // https://www.odata.org/documentation/odata-version-2-0/uri-conventions/ section 4.5
+       // We don't support these queries, just extract the search term.
+       filter := ctx.FormTrim("$filter")
+       match := searchTermExtract.FindStringSubmatch(filter)
+       if len(match) == 2 {
+               return packages_model.SearchValue{
+                       Value:      strings.TrimSpace(match[1]),
+                       ExactMatch: searchTermExact.MatchString(filter),
                }
        }
-       return searchTerm
+
+       return packages_model.SearchValue{}
 }
 
 // https://github.com/NuGet/NuGet.Client/blob/dev/src/NuGet.Core/NuGet.Protocol/LegacyFeed/V2FeedQueryBuilder.cs
@@ -118,11 +132,9 @@ func SearchServiceV2(ctx *context.Context) {
        paginator := db.NewAbsoluteListOptions(skip, take)
 
        pvs, total, err := packages_model.SearchLatestVersions(ctx, &packages_model.PackageSearchOptions{
-               OwnerID: ctx.Package.Owner.ID,
-               Type:    packages_model.TypeNuGet,
-               Name: packages_model.SearchValue{
-                       Value: getSearchTerm(ctx),
-               },
+               OwnerID:    ctx.Package.Owner.ID,
+               Type:       packages_model.TypeNuGet,
+               Name:       getSearchTerm(ctx),
                IsInternal: optional.Some(false),
                Paginator:  paginator,
        })
@@ -169,10 +181,8 @@ func SearchServiceV2(ctx *context.Context) {
 // http://docs.oasis-open.org/odata/odata/v4.0/errata03/os/complete/part2-url-conventions/odata-v4.0-errata03-os-part2-url-conventions-complete.html#_Toc453752351
 func SearchServiceV2Count(ctx *context.Context) {
        count, err := nuget_model.CountPackages(ctx, &packages_model.PackageSearchOptions{
-               OwnerID: ctx.Package.Owner.ID,
-               Name: packages_model.SearchValue{
-                       Value: getSearchTerm(ctx),
-               },
+               OwnerID:    ctx.Package.Owner.ID,
+               Name:       getSearchTerm(ctx),
                IsInternal: optional.Some(false),
        })
        if err != nil {
index 83947ff9671ec518f424c17db15a86a3081e232c..630b4de3f92b6b4472a4ea0e56e3024e5c4b8485 100644 (file)
@@ -429,22 +429,33 @@ AAAjQmxvYgAAAGm7ENm9SGxMtAFVvPUsPJTF6PbtAAAAAFcVogEJAAAAAQAAAA==`)
 
        t.Run("SearchService", func(t *testing.T) {
                cases := []struct {
-                       Query           string
-                       Skip            int
-                       Take            int
-                       ExpectedTotal   int64
-                       ExpectedResults int
+                       Query              string
+                       Skip               int
+                       Take               int
+                       ExpectedTotal      int64
+                       ExpectedResults    int
+                       ExpectedExactMatch bool
                }{
-                       {"", 0, 0, 1, 1},
-                       {"", 0, 10, 1, 1},
-                       {"gitea", 0, 10, 0, 0},
-                       {"test", 0, 10, 1, 1},
-                       {"test", 1, 10, 1, 0},
+                       {"", 0, 0, 4, 4, false},
+                       {"", 0, 10, 4, 4, false},
+                       {"gitea", 0, 10, 0, 0, false},
+                       {"test", 0, 10, 1, 1, false},
+                       {"test", 1, 10, 1, 0, false},
+                       {"almost.similar", 0, 0, 3, 3, true},
                }
 
-               req := NewRequestWithBody(t, "PUT", url, createPackage(packageName, "1.0.99")).
-                       AddBasicAuth(user.Name)
-               MakeRequest(t, req, http.StatusCreated)
+               fakePackages := []string{
+                       packageName,
+                       "almost.similar.dependency",
+                       "almost.similar",
+                       "almost.similar.dependant",
+               }
+
+               for _, fakePackageName := range fakePackages {
+                       req := NewRequestWithBody(t, "PUT", url, createPackage(fakePackageName, "1.0.99")).
+                               AddBasicAuth(user.Name)
+                       MakeRequest(t, req, http.StatusCreated)
+               }
 
                t.Run("v2", func(t *testing.T) {
                        t.Run("Search()", func(t *testing.T) {
@@ -491,6 +502,63 @@ AAAjQmxvYgAAAGm7ENm9SGxMtAFVvPUsPJTF6PbtAAAAAFcVogEJAAAAAQAAAA==`)
                                }
                        })
 
+                       t.Run("Packages()", func(t *testing.T) {
+                               defer tests.PrintCurrentTest(t)()
+
+                               t.Run("substringof", func(t *testing.T) {
+                                       defer tests.PrintCurrentTest(t)()
+
+                                       for i, c := range cases {
+                                               req := NewRequest(t, "GET", fmt.Sprintf("%s/Packages()?$filter=substringof('%s',tolower(Id))&$skip=%d&$top=%d", url, c.Query, c.Skip, c.Take)).
+                                                       AddBasicAuth(user.Name)
+                                               resp := MakeRequest(t, req, http.StatusOK)
+
+                                               var result FeedResponse
+                                               decodeXML(t, resp, &result)
+
+                                               assert.Equal(t, c.ExpectedTotal, result.Count, "case %d: unexpected total hits", i)
+                                               assert.Len(t, result.Entries, c.ExpectedResults, "case %d: unexpected result count", i)
+
+                                               req = NewRequest(t, "GET", fmt.Sprintf("%s/Packages()/$count?$filter=substringof('%s',tolower(Id))&$skip=%d&$top=%d", url, c.Query, c.Skip, c.Take)).
+                                                       AddBasicAuth(user.Name)
+                                               resp = MakeRequest(t, req, http.StatusOK)
+
+                                               assert.Equal(t, strconv.FormatInt(c.ExpectedTotal, 10), resp.Body.String(), "case %d: unexpected total hits", i)
+                                       }
+                               })
+
+                               t.Run("IdEq", func(t *testing.T) {
+                                       defer tests.PrintCurrentTest(t)()
+
+                                       for i, c := range cases {
+                                               if c.Query == "" {
+                                                       // Ignore the `tolower(Id) eq ''` as it's unlikely to happen
+                                                       continue
+                                               }
+                                               req := NewRequest(t, "GET", fmt.Sprintf("%s/Packages()?$filter=(tolower(Id) eq '%s')&$skip=%d&$top=%d", url, c.Query, c.Skip, c.Take)).
+                                                       AddBasicAuth(user.Name)
+                                               resp := MakeRequest(t, req, http.StatusOK)
+
+                                               var result FeedResponse
+                                               decodeXML(t, resp, &result)
+
+                                               expectedCount := 0
+                                               if c.ExpectedExactMatch {
+                                                       expectedCount = 1
+                                               }
+
+                                               assert.Equal(t, int64(expectedCount), result.Count, "case %d: unexpected total hits", i)
+                                               assert.Len(t, result.Entries, expectedCount, "case %d: unexpected result count", i)
+
+                                               req = NewRequest(t, "GET", fmt.Sprintf("%s/Packages()/$count?$filter=(tolower(Id) eq '%s')&$skip=%d&$top=%d", url, c.Query, c.Skip, c.Take)).
+                                                       AddBasicAuth(user.Name)
+                                               resp = MakeRequest(t, req, http.StatusOK)
+
+                                               assert.Equal(t, strconv.FormatInt(int64(expectedCount), 10), resp.Body.String(), "case %d: unexpected total hits", i)
+                                       }
+                               })
+                       })
+
                        t.Run("Next", func(t *testing.T) {
                                req := NewRequest(t, "GET", fmt.Sprintf("%s/Search()?searchTerm='test'&$skip=0&$top=1", url)).
                                        AddBasicAuth(user.Name)
@@ -548,9 +616,11 @@ AAAjQmxvYgAAAGm7ENm9SGxMtAFVvPUsPJTF6PbtAAAAAFcVogEJAAAAAQAAAA==`)
                        })
                })
 
-               req = NewRequest(t, "DELETE", fmt.Sprintf("%s/%s/%s", url, packageName, "1.0.99")).
-                       AddBasicAuth(user.Name)
-               MakeRequest(t, req, http.StatusNoContent)
+               for _, fakePackageName := range fakePackages {
+                       req := NewRequest(t, "DELETE", fmt.Sprintf("%s/%s/%s", url, fakePackageName, "1.0.99")).
+                               AddBasicAuth(user.Name)
+                       MakeRequest(t, req, http.StatusNoContent)
+               }
        })
 
        t.Run("RegistrationService", func(t *testing.T) {