Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions api/dbv1/tracks.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ func (q *Queries) TracksKeyed(ctx context.Context, arg TracksParams) (map[int32]

userIds := []int32{}
trackIds := make([]int32, 0, len(rawTracks))
ownerByTrack := make(map[int32]int32, len(rawTracks))
collectSplitUserIds := func(usage *AccessGate) {
if usage == nil || usage.UsdcPurchase == nil {
return
Expand All @@ -62,6 +63,7 @@ func (q *Queries) TracksKeyed(ctx context.Context, arg TracksParams) (map[int32]
for _, rawTrack := range rawTracks {
userIds = append(userIds, rawTrack.UserID)
trackIds = append(trackIds, rawTrack.TrackID)
ownerByTrack[rawTrack.TrackID] = rawTrack.UserID

var remixOf RemixOf
json.Unmarshal(rawTrack.RemixOf, &remixOf)
Expand Down Expand Up @@ -169,15 +171,15 @@ func (q *Queries) TracksKeyed(ctx context.Context, arg TracksParams) (map[int32]

// Resolve accepted collaborators (order preserved from the query).
collaborators := []User{}
for _, cid := range collaboratorsByTrack[rawTrack.TrackID] {
for _, cid := range uniqueCollaboratorIDs(collaboratorsByTrack[rawTrack.TrackID], ownerByTrack[rawTrack.TrackID]) {
if cu, ok := userMap[cid]; ok {
collaborators = append(collaborators, cu)
}
}

// Resolve pending collaborators (only present for the owner's own tracks).
pendingCollaborators := []User{}
for _, cid := range pendingByTrack[rawTrack.TrackID] {
for _, cid := range uniqueCollaboratorIDs(pendingByTrack[rawTrack.TrackID], ownerByTrack[rawTrack.TrackID]) {
if cu, ok := userMap[cid]; ok {
pendingCollaborators = append(pendingCollaborators, cu)
}
Expand Down Expand Up @@ -264,3 +266,20 @@ func (q *Queries) Tracks(ctx context.Context, arg TracksParams) ([]Track, error)

return tracks, nil
}

func uniqueCollaboratorIDs(collaboratorIDs []int32, ownerID int32) []int32 {
if len(collaboratorIDs) == 0 {
return nil
}

seen := map[int32]struct{}{ownerID: {}}
uniqueIDs := make([]int32, 0, len(collaboratorIDs))
for _, collaboratorID := range collaboratorIDs {
if _, ok := seen[collaboratorID]; ok {
continue
}
seen[collaboratorID] = struct{}{}
uniqueIDs = append(uniqueIDs, collaboratorID)
}
return uniqueIDs
}
16 changes: 16 additions & 0 deletions api/dbv1/tracks_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package dbv1

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestUniqueCollaboratorIDs(t *testing.T) {
assert.Equal(
t,
[]int32{2, 3, 4},
uniqueCollaboratorIDs([]int32{2, 2, 1, 3, 3, 4}, 1),
)
assert.Nil(t, uniqueCollaboratorIDs(nil, 1))
}
5 changes: 5 additions & 0 deletions api/v1_track_collaborators_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ func seedCollaborators(t *testing.T, app *ApiServer) {
func TestTrackCollaboratorsEmbeddedOnTrack(t *testing.T) {
app := testAppWithFixtures(t)
seedCollaborators(t, app)
now := time.Now()
database.SeedTable(app.pool.Replicas[0], "track_collaborators", []map[string]any{
{"track_id": 700, "collaborator_user_id": 500, "invited_by": 500, "status": "accepted", "created_at": now, "updated_at": now},
})

var resp struct {
Data []dbv1.Track
Expand All @@ -42,6 +46,7 @@ func TestTrackCollaboratorsEmbeddedOnTrack(t *testing.T) {
"data.3.id": trashid.MustEncodeHashID(700),
"data.3.collaborators.0.handle": "rayjacobson",
})
assert.Len(t, resp.Data[3].Collaborators, 1, "owner should not be embedded as their own collaborator")
// Non-collaborated tracks carry an empty array, not null.
assert.Contains(t, string(body), `"collaborators":[]`)
}
Expand Down
Loading