diff --git a/.github/workflows/go-lint.yml b/.github/workflows/go.yaml similarity index 75% rename from .github/workflows/go-lint.yml rename to .github/workflows/go.yaml index 1187b31..cd3bf9d 100644 --- a/.github/workflows/go-lint.yml +++ b/.github/workflows/go.yaml @@ -1,4 +1,4 @@ -name: Golang Linting +name: Golang Validation on: push: @@ -29,3 +29,10 @@ jobs: - name: Check Format run: | gofmt -s -l database logging sse *.go + - name: Run Tests + run: | + go test ./database + - name: Run vet + run: | + go vet ./database/ ./logging/ ./sse/ + go vet *.go diff --git a/README.md b/README.md index a8deb9f..15b7eb3 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,13 @@ go mod tidy # format all code according to go standards gofmt -w -s *.go logging sse database + +# run tests (database is the first place we've defined tests) +go test ./database + +# run heuristic validation +go vet ./database/ ./logging/ ./sse/ +go vet *.go ``` ## To-Dos diff --git a/database/database.go b/database/database.go index 423344f..d03bc42 100644 --- a/database/database.go +++ b/database/database.go @@ -4,6 +4,7 @@ import ( "context" "os" "strings" + "testing" "time" "github.com/computersciencehouse/vote/logging" @@ -20,10 +21,16 @@ const ( Updated UpsertResult = 1 ) -var Client = Connect() +var Client *mongo.Client = Connect() var db = "" func Connect() *mongo.Client { + // This always gets invoked on initialisation. bad! it'd be nice if we only did this setup in main rather than components under test. for now we just skip if testing + if testing.Testing() { + logging.Logger.WithFields(logrus.Fields{"module": "database", "method": "Connect"}).Info("testing, not doing db connection, someone should mock this someday") + return nil + } + logging.Logger.WithFields(logrus.Fields{"module": "database", "method": "Connect"}).Info("beginning database connection") ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second) diff --git a/database/poll.go b/database/poll.go index fcae1db..4f16afb 100644 --- a/database/poll.go +++ b/database/poll.go @@ -2,11 +2,15 @@ package database import ( "context" + "sort" "time" + "github.com/sirupsen/logrus" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" + + "github.com/computersciencehouse/vote/logging" ) type Poll struct { @@ -139,32 +143,32 @@ func GetClosedVotedPolls(ctx context.Context, userId string) ([]*Poll, error) { cursor, err := Client.Database(db).Collection("votes").Aggregate(ctx, mongo.Pipeline{ {{ - "$match", bson.D{ - {"userId", userId}, + Key: "$match", Value: bson.D{ + {Key: "userId", Value: userId}, }, }}, {{ - "$lookup", bson.D{ - {"from", "polls"}, - {"localField", "pollId"}, - {"foreignField", "_id"}, - {"as", "polls"}, + Key: "$lookup", Value: bson.D{ + {Key: "from", Value: "polls"}, + {Key: "localField", Value: "pollId"}, + {Key: "foreignField", Value: "_id"}, + {Key: "as", Value: "polls"}, }, }}, {{ - "$unwind", bson.D{ - {"path", "$polls"}, - {"preserveNullAndEmptyArrays", false}, + Key: "$unwind", Value: bson.D{ + {Key: "path", Value: "$polls"}, + {Key: "preserveNullAndEmptyArrays", Value: false}, }, }}, {{ - "$replaceRoot", bson.D{ - {"newRoot", "$polls"}, + Key: "$replaceRoot", Value: bson.D{ + {Key: "newRoot", Value: "$polls"}, }, }}, {{ - "$match", bson.D{ - {"open", false}, + Key: "$match", Value: bson.D{ + {Key: "open", Value: false}, }, }}, }) @@ -178,6 +182,105 @@ func GetClosedVotedPolls(ctx context.Context, userId string) ([]*Poll, error) { return polls, nil } +// calculateRankedResult determines a result for a ranked choice vote +// votesRaw is the RankedVote entries that are returned directly from the database +// The algorithm defined in the Constitution as of 26 Nov 2025 is as follows: +// +// > The winning option is selected outright if it gains more than half the votes +// > cast as a first preference. If not, the option with the fewest number of first +// > preference votes is eliminated and their votes move to the second preference +// > marked on the ballots. This process continues until one option has half of the +// > votes cast and is elected. +// +// The return value consists of a list of voting rounds. Each round contains a +// mapping of the vote options to their vote share for that round. If the vote +// is not decided in a given round, there will be a subsequent round with the +// option that had the fewest votes eliminated, and its votes redistributed. +// +// The last entry in this list is the final round, and the option with the most +// votes in this round is the winner. If all options have the same, then it is +// unfortunately a tie, and the vote is not resolvable, as there is no lowest +// option to eliminate. +func calculateRankedResult(ctx context.Context, votesRaw []RankedVote) ([]map[string]int, error) { + // We want to store those that were eliminated so we don't accidentally reinclude them + eliminated := make([]string, 0) + votes := make([][]string, 0) + finalResult := make([]map[string]int, 0) + + //change ranked votes from a map (which is unordered) to a slice of votes (which is ordered) + //order is from first preference to last preference + for _, vote := range votesRaw { + optionList := orderOptions(ctx, vote.Options) + votes = append(votes, optionList) + } + + round := 0 + // Iterate until we have a winner + for { + round = round + 1 + // Contains candidates to number of votes in this round + tallied := make(map[string]int) + voteCount := 0 + for _, picks := range votes { + // Go over picks until we find a non-eliminated candidate + for _, candidate := range picks { + if !containsValue(eliminated, candidate) { + if _, ok := tallied[candidate]; ok { + tallied[candidate]++ + } else { + tallied[candidate] = 1 + } + voteCount += 1 + break + } + } + } + // Eliminate lowest vote getter + minVote := 1000000 //the smallest number of votes received thus far (to find who is in last) + minPerson := make([]string, 0) //the person(s) with the least votes that need removed + for person, vote := range tallied { + if vote < minVote { // this should always be true round one, to set a true "who is in last" + minVote = vote + minPerson = make([]string, 0) + minPerson = append(minPerson, person) + } else if vote == minVote { + minPerson = append(minPerson, person) + } + } + eliminated = append(eliminated, minPerson...) + finalResult = append(finalResult, tallied) + + // TODO this should probably include some poll identifier + logging.Logger.WithFields(logrus.Fields{"round": round, "tallies": tallied, "threshold": voteCount / 2}).Debug("round report") + + // If one person has all the votes, they win + if len(tallied) == 1 { + break + } + + end := true + for str, val := range tallied { + // if any particular entry is above half remaining votes, they win and it ends + if val > (voteCount / 2) { + finalResult = append(finalResult, map[string]int{str: val}) + end = true + break + } + // Check if all values in tallied are the same + // In that case, it's a tie? + if val != minVote { + end = false + break + } + } + if end { + break + } + } + return finalResult, nil + +} + func (poll *Poll) GetResult(ctx context.Context) ([]map[string]int, error) { ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() @@ -191,15 +294,15 @@ func (poll *Poll) GetResult(ctx context.Context) ([]map[string]int, error) { pollResult := make(map[string]int) cursor, err := Client.Database(db).Collection("votes").Aggregate(ctx, mongo.Pipeline{ {{ - "$match", bson.D{ - {"pollId", pollId}, + Key: "$match", Value: bson.D{ + {Key: "pollId", Value: pollId}, }, }}, {{ - "$group", bson.D{ - {"_id", "$option"}, - {"count", bson.D{ - {"$sum", 1}, + Key: "$group", Value: bson.D{ + {Key: "_id", Value: "$option"}, + {Key: "count", Value: bson.D{ + {Key: "$sum", Value: 1}, }}, }, }}, @@ -223,14 +326,11 @@ func (poll *Poll) GetResult(ctx context.Context) ([]map[string]int, error) { return finalResult, nil case POLL_TYPE_RANKED: - // We want to store those that were eliminated - eliminated := make([]string, 0) - // Get all votes cursor, err := Client.Database(db).Collection("votes").Aggregate(ctx, mongo.Pipeline{ {{ - "$match", bson.D{ - {"pollId", pollId}, + Key: "$match", Value: bson.D{ + {Key: "pollId", Value: pollId}, }, }}, }) @@ -239,76 +339,7 @@ func (poll *Poll) GetResult(ctx context.Context) ([]map[string]int, error) { } var votesRaw []RankedVote cursor.All(ctx, &votesRaw) - - votes := make([][]string, 0) - - //change ranked votes from a map (which is unordered) to a slice of votes (which is ordered) - //order is from first preference to last preference - for _, vote := range votesRaw { - temp, cf := context.WithTimeout(context.Background(), 1*time.Second) - optionList := orderOptions(vote.Options, temp) - cf() - votes = append(votes, optionList) - } - - // Iterate until we have a winner - for { - // Contains candidates to number of votes in this round - tallied := make(map[string]int) - voteCount := 0 - for _, picks := range votes { - // Go over picks until we find a non-eliminated candidate - for _, candidate := range picks { - if !containsValue(eliminated, candidate) { - if _, ok := tallied[candidate]; ok { - tallied[candidate]++ - } else { - tallied[candidate] = 1 - } - voteCount += 1 - break - } - } - } - // Eliminate lowest vote getter - minVote := 1000000 //the smallest number of votes received thus far (to find who is in last) - minPerson := make([]string, 0) //the person(s) with the least votes that need removed - for person, vote := range tallied { - if vote < minVote { // this should always be true round one, to set a true "who is in last" - minVote = vote - minPerson = make([]string, 0) - minPerson = append(minPerson, person) - } else if vote == minVote { - minPerson = append(minPerson, person) - } - } - eliminated = append(eliminated, minPerson...) - finalResult = append(finalResult, tallied) - // If one person has all the votes, they win - if len(tallied) == 1 { - break - } - - end := true - for str, val := range tallied { - // if any particular entry is above half remaining votes, they win and it ends - if val > (voteCount / 2) { - finalResult = append(finalResult, map[string]int{str: val}) - end = true - break - } - // Check if all values in tallied are the same - // In that case, it's a tie? - if val != minVote { - end = false - break - } - } - if end { - break - } - } - return finalResult, nil + return calculateRankedResult(ctx, votesRaw) } return nil, nil } @@ -322,21 +353,35 @@ func containsValue(slice []string, value string) bool { return false } -func orderOptions(options map[string]int, ctx context.Context) []string { - result := make([]string, 0, len(options)) - order := 1 - for order <= len(options) { - for option, preference := range options { - select { - case <-ctx.Done(): - return make([]string, 0) - default: - if preference == order { - result = append(result, option) - order += 1 - } - } - } +// orderOptions takes a RankedVote's options, and returns an ordered list of +// their choices +// +// it's invalid for a vote to list the same number multiple times, the output +// will vary based on the map ordering of the options, and so is not guaranteed +// to be deterministic +// +// ctx is no longer used, as this function is not expected to hang, but remains +// an argument per golang standards +// +// the return values is the option keys, ordered from lowest to highest +func orderOptions(ctx context.Context, options map[string]int) []string { + // Figure out all the ranks they've listed + var ranks []int = make([]int, len(options)) + reverse_map := make(map[int]string) + i := 0 + for option, rank := range options { + ranks[i] = rank + reverse_map[rank] = option + i += 1 + } + + sort.Ints(ranks) + + // normalise the ranks for counts that don't start at 1 + var choices []string = make([]string, len(ranks)) + for idx, rank := range ranks { + choices[idx] = reverse_map[rank] } - return result + + return choices } diff --git a/database/poll_test.go b/database/poll_test.go new file mode 100644 index 0000000..9787399 --- /dev/null +++ b/database/poll_test.go @@ -0,0 +1,225 @@ +package database + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func makeVotes() []RankedVote { + // so inpyt for this, we want to have option, then a list of ranks. + // am tempted to have some shorthand for generating test cases more easily + return []RankedVote{} +} + +func TestResultCalcs(t *testing.T) { + // for votes, we only need to define options, we don't currently rely on IDs + tests := []struct { + name string + votes []RankedVote + results []map[string]int + err error + }{ + { + name: "Empty Votes", + votes: []RankedVote{ + { + Options: map[string]int{}, + }, + }, + results: []map[string]int{ + {}, + }, + }, + { + name: "1 vote", + votes: []RankedVote{ + { + Options: map[string]int{ + "first": 1, + "second": 2, + "third": 3, + }, + }, + }, + results: []map[string]int{ + { + "first": 1, + }, + }, + }, + { + name: "Tie vote", + votes: []RankedVote{ + { + Options: map[string]int{ + "first": 1, + "second": 2, + }, + }, + { + Options: map[string]int{ + "first": 2, + "second": 1, + }, + }, + }, + results: []map[string]int{ + { + "first": 1, + "second": 1, + }, + }, + }, + { + name: "Several Rounds", + votes: []RankedVote{ + { + Options: map[string]int{ + "a": 1, + "b": 2, + "c": 3, + }, + }, + { + Options: map[string]int{ + "a": 2, + "b": 1, + "c": 3, + }, + }, + { + Options: map[string]int{ + "a": 1, + "b": 2, + "c": 3, + }, + }, + { + Options: map[string]int{ + "a": 2, + "b": 1, + "c": 3, + }, + }, + { + Options: map[string]int{ + "a": 2, + "b": 3, + "c": 1, + }, + }, + }, + results: []map[string]int{ + { + "a": 2, + "b": 2, + "c": 1, + }, + { + "a": 3, + "b": 2, + }, + { + "a": 3, + }, + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + results, err := calculateRankedResult(context.Background(), test.votes) + assert.Equal(t, test.results, results) + assert.Equal(t, test.err, err) + }) + } +} + +func TestOrderOptions(t *testing.T) { + tests := []struct { + name string + input map[string]int + output []string + }{ + { + name: "SimpleOrder", + input: map[string]int{ + "one": 1, + "two": 2, + "three": 3, + "four": 4, + }, + output: []string{"one", "two", "three", "four"}, + }, + { + name: "Reversed", + input: map[string]int{ + "one": 4, + "two": 3, + "three": 2, + "four": 1, + }, + output: []string{"four", "three", "two", "one"}, + }, + { + name: "HighStart", + input: map[string]int{ + "one": 2, + "two": 3, + "three": 4, + "four": 5, + }, + output: []string{"one", "two", "three", "four"}, + }, + { + name: "LowStart", + input: map[string]int{ + "one": 0, + "two": 1, + "three": 2, + "four": 3, + }, + output: []string{"one", "two", "three", "four"}, + }, + { + name: "Negative", + input: map[string]int{ + "one": -1, + "two": 1, + "three": 2, + "four": 3, + }, + output: []string{"one", "two", "three", "four"}, + }, + { + name: "duplicate, expect bad output", + input: map[string]int{ + "one": 0, + "two": 1, + "three": 1, + "four": 2, + }, + output: []string{"one", "three", "three", "four"}, + }, + { + name: "Gap", + input: map[string]int{ + "one": 1, + "two": 2, + "three": 4, + "four": 5, + }, + output: []string{"one", "two", "three", "four"}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + assert.Equal(t, test.output, orderOptions(ctx, test.input)) + }) + } +} diff --git a/go.mod b/go.mod index cb655bd..c16c793 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/gin-gonic/gin v1.11.0 github.com/sirupsen/logrus v1.9.3 github.com/slack-go/slack v0.17.3 + github.com/stretchr/testify v1.11.1 go.mongodb.org/mongo-driver v1.17.6 mvdan.cc/xurls/v2 v2.6.0 ) @@ -17,6 +18,7 @@ require ( github.com/bytedance/sonic/loader v0.4.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/coreos/go-oidc v2.4.0+incompatible // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/gabriel-vasile/mimetype v1.4.11 // indirect github.com/gin-contrib/sse v1.1.0 // indirect github.com/go-playground/locales v0.14.1 // indirect @@ -36,6 +38,7 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/montanaflynn/stats v0.7.1 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pquerna/cachecontrol v0.2.0 // indirect github.com/quic-go/qpack v0.5.1 // indirect github.com/quic-go/quic-go v0.55.0 // indirect @@ -57,4 +60,5 @@ require ( golang.org/x/tools v0.38.0 // indirect google.golang.org/protobuf v1.36.10 // indirect gopkg.in/go-jose/go-jose.v2 v2.6.3 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index ae41f3b..b325fdf 100644 --- a/go.sum +++ b/go.sum @@ -148,6 +148,7 @@ golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/go-jose/go-jose.v2 v2.6.3 h1:nt80fvSDlhKWQgSWyHyy5CfmlQr+asih51R8PTWNKKs= gopkg.in/go-jose/go-jose.v2 v2.6.3/go.mod h1:zzZDPkNNw/c9IE7Z9jr11mBZQhKQTMzoEEIoEdZlFBI= diff --git a/logging/logger.go b/logging/logger.go index 3956b17..f3d2d6c 100644 --- a/logging/logger.go +++ b/logging/logger.go @@ -3,19 +3,29 @@ package logging import ( "os" "runtime" + "testing" "github.com/sirupsen/logrus" ) -var Logger = &logrus.Logger{ - Out: os.Stdout, - Formatter: &logrus.TextFormatter{ - DisableLevelTruncation: true, - PadLevelText: true, - FullTimestamp: true, - }, - Hooks: make(logrus.LevelHooks), - Level: logrus.InfoLevel, +var Logger *logrus.Logger = makeLogger() + +func makeLogger() *logrus.Logger { + // TODO should this someday be configurable? + level := logrus.InfoLevel + if testing.Testing() { + level = logrus.DebugLevel + } + return &logrus.Logger{ + Out: os.Stdout, + Formatter: &logrus.TextFormatter{ + DisableLevelTruncation: true, + PadLevelText: true, + FullTimestamp: true, + }, + Hooks: make(logrus.LevelHooks), + Level: level, + } } func Trace() runtime.Frame {