diff --git a/api/dbv1/tracks.go b/api/dbv1/tracks.go index 542a4338..0f2357ab 100644 --- a/api/dbv1/tracks.go +++ b/api/dbv1/tracks.go @@ -50,6 +50,7 @@ func (q *Queries) TracksKeyed(ctx context.Context, arg TracksParams) (map[int32] userIds := []int32{} trackIds := make([]int32, 0, len(rawTracks)) + ownerByTrack := make(map[int32]int32, len(rawTracks)) collectSplitUserIds := func(usage *AccessGate) { if usage == nil || usage.UsdcPurchase == nil { return @@ -62,6 +63,7 @@ func (q *Queries) TracksKeyed(ctx context.Context, arg TracksParams) (map[int32] for _, rawTrack := range rawTracks { userIds = append(userIds, rawTrack.UserID) trackIds = append(trackIds, rawTrack.TrackID) + ownerByTrack[rawTrack.TrackID] = rawTrack.UserID var remixOf RemixOf json.Unmarshal(rawTrack.RemixOf, &remixOf) @@ -169,7 +171,7 @@ func (q *Queries) TracksKeyed(ctx context.Context, arg TracksParams) (map[int32] // Resolve accepted collaborators (order preserved from the query). collaborators := []User{} - for _, cid := range collaboratorsByTrack[rawTrack.TrackID] { + for _, cid := range uniqueCollaboratorIDs(collaboratorsByTrack[rawTrack.TrackID], ownerByTrack[rawTrack.TrackID]) { if cu, ok := userMap[cid]; ok { collaborators = append(collaborators, cu) } @@ -177,7 +179,7 @@ func (q *Queries) TracksKeyed(ctx context.Context, arg TracksParams) (map[int32] // Resolve pending collaborators (only present for the owner's own tracks). pendingCollaborators := []User{} - for _, cid := range pendingByTrack[rawTrack.TrackID] { + for _, cid := range uniqueCollaboratorIDs(pendingByTrack[rawTrack.TrackID], ownerByTrack[rawTrack.TrackID]) { if cu, ok := userMap[cid]; ok { pendingCollaborators = append(pendingCollaborators, cu) } @@ -264,3 +266,20 @@ func (q *Queries) Tracks(ctx context.Context, arg TracksParams) ([]Track, error) return tracks, nil } + +func uniqueCollaboratorIDs(collaboratorIDs []int32, ownerID int32) []int32 { + if len(collaboratorIDs) == 0 { + return nil + } + + seen := map[int32]struct{}{ownerID: {}} + uniqueIDs := make([]int32, 0, len(collaboratorIDs)) + for _, collaboratorID := range collaboratorIDs { + if _, ok := seen[collaboratorID]; ok { + continue + } + seen[collaboratorID] = struct{}{} + uniqueIDs = append(uniqueIDs, collaboratorID) + } + return uniqueIDs +} diff --git a/api/dbv1/tracks_test.go b/api/dbv1/tracks_test.go new file mode 100644 index 00000000..2a0c2b38 --- /dev/null +++ b/api/dbv1/tracks_test.go @@ -0,0 +1,16 @@ +package dbv1 + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestUniqueCollaboratorIDs(t *testing.T) { + assert.Equal( + t, + []int32{2, 3, 4}, + uniqueCollaboratorIDs([]int32{2, 2, 1, 3, 3, 4}, 1), + ) + assert.Nil(t, uniqueCollaboratorIDs(nil, 1)) +} diff --git a/api/v1_track_collaborators_test.go b/api/v1_track_collaborators_test.go index c2eee2b5..803be47d 100644 --- a/api/v1_track_collaborators_test.go +++ b/api/v1_track_collaborators_test.go @@ -31,6 +31,10 @@ func seedCollaborators(t *testing.T, app *ApiServer) { func TestTrackCollaboratorsEmbeddedOnTrack(t *testing.T) { app := testAppWithFixtures(t) seedCollaborators(t, app) + now := time.Now() + database.SeedTable(app.pool.Replicas[0], "track_collaborators", []map[string]any{ + {"track_id": 700, "collaborator_user_id": 500, "invited_by": 500, "status": "accepted", "created_at": now, "updated_at": now}, + }) var resp struct { Data []dbv1.Track @@ -42,6 +46,7 @@ func TestTrackCollaboratorsEmbeddedOnTrack(t *testing.T) { "data.3.id": trashid.MustEncodeHashID(700), "data.3.collaborators.0.handle": "rayjacobson", }) + assert.Len(t, resp.Data[3].Collaborators, 1, "owner should not be embedded as their own collaborator") // Non-collaborated tracks carry an empty array, not null. assert.Contains(t, string(body), `"collaborators":[]`) }