Quellcode durchsuchen

Add API parameter to filter entries by category

Frédéric Guillot vor 6 Jahren
Ursprung
Commit
e878dca3d7
6 geänderte Dateien mit 58 neuen und 17 gelöschten Zeilen
  1. 10 5
      api/entry.go
  2. 2 1
      client/core.go
  3. 1 1
      storage/entry_query_builder.go
  4. 35 1
      tests/entry_test.go
  5. 2 2
      tests/subscription_test.go
  6. 8 7
      tests/tests.go

+ 10 - 5
api/entry.go

@@ -164,7 +164,7 @@ func (h *handler) getEntries(w http.ResponseWriter, r *http.Request) {
 func (h *handler) setEntryStatus(w http.ResponseWriter, r *http.Request) {
 	entryIDs, status, err := decodeEntryStatusPayload(r.Body)
 	if err != nil {
-		json.BadRequest(w , r, errors.New("Invalid JSON payload"))
+		json.BadRequest(w, r, errors.New("Invalid JSON payload"))
 		return
 	}
 
@@ -193,25 +193,30 @@ func (h *handler) toggleBookmark(w http.ResponseWriter, r *http.Request) {
 
 func configureFilters(builder *storage.EntryQueryBuilder, r *http.Request) {
 	beforeEntryID := request.QueryInt64Param(r, "before_entry_id", 0)
-	if beforeEntryID != 0 {
+	if beforeEntryID > 0 {
 		builder.BeforeEntryID(beforeEntryID)
 	}
 
 	afterEntryID := request.QueryInt64Param(r, "after_entry_id", 0)
-	if afterEntryID != 0 {
+	if afterEntryID > 0 {
 		builder.AfterEntryID(afterEntryID)
 	}
 
 	beforeTimestamp := request.QueryInt64Param(r, "before", 0)
-	if beforeTimestamp != 0 {
+	if beforeTimestamp > 0 {
 		builder.BeforeDate(time.Unix(beforeTimestamp, 0))
 	}
 
 	afterTimestamp := request.QueryInt64Param(r, "after", 0)
-	if afterTimestamp != 0 {
+	if afterTimestamp > 0 {
 		builder.AfterDate(time.Unix(afterTimestamp, 0))
 	}
 
+	categoryID := request.QueryInt64Param(r, "category_id", 0)
+	if categoryID > 0 {
+		builder.WithCategoryID(categoryID)
+	}
+
 	if request.HasQueryParam(r, "starred") {
 		builder.WithStarred()
 	}

+ 2 - 1
client/core.go

@@ -48,7 +48,7 @@ type UserModification struct {
 // Users represents a list of users.
 type Users []User
 
-// Category represents a category in the system.
+// Category represents a feed category.
 type Category struct {
 	ID     int64  `json:"id,omitempty"`
 	Title  string `json:"title,omitempty"`
@@ -169,6 +169,7 @@ type Filter struct {
 	BeforeEntryID int64
 	AfterEntryID  int64
 	Search        string
+	CategoryID    int64
 }
 
 // EntryResultSet represents the response when fetching entries.

+ 1 - 1
storage/entry_query_builder.go

@@ -103,7 +103,7 @@ func (e *EntryQueryBuilder) WithFeedID(feedID int64) *EntryQueryBuilder {
 
 // WithCategoryID set the categoryID.
 func (e *EntryQueryBuilder) WithCategoryID(categoryID int64) *EntryQueryBuilder {
-	if categoryID != 0 {
+	if categoryID > 0 {
 		e.conditions = append(e.conditions, fmt.Sprintf("f.category_id = $%d", len(e.args)+1))
 		e.args = append(e.args, categoryID)
 	}

+ 35 - 1
tests/entry_test.go

@@ -93,6 +93,40 @@ func TestGetAllEntries(t *testing.T) {
 	}
 }
 
+func TestFilterEntriesByCategory(t *testing.T) {
+	client := createClient(t)
+	category, err := client.CreateCategory("Test Filter by Category")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	feedID, err := client.CreateFeed(testFeedURL, category.ID)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if feedID == 0 {
+		t.Fatalf(`Invalid feed ID, got %q`, feedID)
+	}
+
+	results, err := client.Entries(&miniflux.Filter{CategoryID: category.ID})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if results.Total == 0 {
+		t.Fatalf(`We should have more than one entry`)
+	}
+
+	if results.Entries[0].Feed.Category == nil {
+		t.Fatalf(`The entry feed category should not be nil`)
+	}
+
+	if results.Entries[0].Feed.Category.ID != category.ID {
+		t.Errorf(`Entries should be filtered by category_id=%d`, category.ID)
+	}
+}
+
 func TestSearchEntries(t *testing.T) {
 	client := createClient(t)
 	categories, err := client.Categories()
@@ -100,7 +134,7 @@ func TestSearchEntries(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	feedID, err := client.CreateFeed("https://miniflux.app/feed.xml", categories[0].ID)
+	feedID, err := client.CreateFeed(testFeedURL, categories[0].ID)
 	if err != nil {
 		t.Fatal(err)
 	}

+ 2 - 2
tests/subscription_test.go

@@ -21,8 +21,8 @@ func TestDiscoverSubscriptions(t *testing.T) {
 		t.Fatalf(`Invalid number of subscriptions, got "%v" instead of "%v"`, len(subscriptions), 2)
 	}
 
-	if subscriptions[0].Title != testFeedTitle {
-		t.Fatalf(`Invalid feed title, got "%v" instead of "%v"`, subscriptions[0].Title, testFeedTitle)
+	if subscriptions[0].Title != testSubscriptionTitle {
+		t.Fatalf(`Invalid feed title, got "%v" instead of "%v"`, subscriptions[0].Title, testSubscriptionTitle)
 	}
 
 	if subscriptions[0].Type != "atom" {

+ 8 - 7
tests/tests.go

@@ -15,13 +15,14 @@ import (
 )
 
 const (
-	testBaseURL          = "http://127.0.0.1:8080/"
-	testAdminUsername    = "admin"
-	testAdminPassword    = "test123"
-	testStandardPassword = "secret"
-	testFeedURL          = "https://github.com/miniflux/miniflux/commits/master.atom"
-	testFeedTitle        = "Recent Commits to miniflux:master"
-	testWebsiteURL       = "https://github.com/miniflux/miniflux/commits/master"
+	testBaseURL           = "http://127.0.0.1:8080/"
+	testAdminUsername     = "admin"
+	testAdminPassword     = "test123"
+	testStandardPassword  = "secret"
+	testFeedURL           = "https://miniflux.app/feed.xml"
+	testFeedTitle         = "Miniflux"
+	testSubscriptionTitle = "Miniflux Releases"
+	testWebsiteURL        = "https://miniflux.app/"
 )
 
 func getRandomUsername() string {