diff --git a/discord/discord.go b/discord/discord.go index 5bc34a5..6b0ba6d 100644 --- a/discord/discord.go +++ b/discord/discord.go @@ -11,7 +11,15 @@ func (app *application) messageCreate(s *discordgo.Session, m *discordgo.Message return } - /* Checking if the message starts with the trigger specified in application struct + /* Check if the user is even allowed by the rate limiter */ + err := app.limiter.CheckAllowed(m.Author.ID) + if err != nil { + /* normally don't send, but now do for debug purposes. This is the admin bot channel */ + app.unknownError(err, s, true, "815952128106430514") + return + } + + /* Checking if the message starts with the trigger specified in application struct if it does then we start the switch statement to trigger the appropriate function if it does not then we check if it contains a triggerword from the database */ if strings.HasPrefix(m.Content, app.trigger) { @@ -52,14 +60,13 @@ func (app *application) messageCreate(s *discordgo.Session, m *discordgo.Message case "removeadmin": app.removeAdmin(s, m, splitCommand) } - + } } else { /* If the trigger wasn't the prefix of the message, we need to check all the words for a trigger */ app.findTrigger(s, m) } - - -} + app.limiter.LogInteraction(m.Author.ID, "messagecreate") +} diff --git a/discord/main.go b/discord/main.go index d3328df..fca10af 100644 --- a/discord/main.go +++ b/discord/main.go @@ -6,8 +6,10 @@ import ( "os" "os/signal" "syscall" + "time" "github.com/bwmarrin/discordgo" + "quenten.nl/pepebot/limiter" "quenten.nl/pepebot/models/mysql" ) @@ -15,15 +17,16 @@ import ( It also has many methods for the different functions of the bot. These methods are mostly located in discord.go */ type application struct { - errorLog *log.Logger - infoLog *log.Logger - badwords *mysql.BadwordModel - adminroles *mysql.AdminRolesModel - trigger string + errorLog *log.Logger + infoLog *log.Logger + badwords *mysql.BadwordModel + adminroles *mysql.AdminRolesModel + trigger string allBadWords map[string][]string + limiter *limiter.Limiter active bool - stop bool + stop bool } func main() { @@ -39,12 +42,18 @@ func main() { errorLog.Fatal(err) } + limiter := &limiter.Limiter{ + RateLimit: 5, + TimeLimit: time.Second * 15, + } + app := &application{ - infoLog: infoLog, - errorLog: errorLog, - badwords: &mysql.BadwordModel{DB: db}, + infoLog: infoLog, + errorLog: errorLog, + badwords: &mysql.BadwordModel{DB: db}, adminroles: &mysql.AdminRolesModel{DB: db}, - trigger: "!pepe", + trigger: "!pepe", + limiter: limiter, } app.allBadWords, err = app.badwords.AllWords() diff --git a/limiter/action.go b/limiter/action.go new file mode 100644 index 0000000..920f92d --- /dev/null +++ b/limiter/action.go @@ -0,0 +1,8 @@ +package limiter + +import "time" + +type Action struct { + Type string + Timestamp time.Time +} diff --git a/limiter/limiter.go b/limiter/limiter.go new file mode 100644 index 0000000..0955cd5 --- /dev/null +++ b/limiter/limiter.go @@ -0,0 +1,55 @@ +package limiter + +import ( + "errors" + "time" +) + +/* The Limiter struct saves all interactions in a map of lists indexed by user id going back the time limit. +When checking if a user is allowed to perform an action, it traverses the list for that userid. +If an item is older than the time limit, remove it and don't count. +If it is in the limit, then count. If the amount of interactions is higher than the limit, return an error */ +type Limiter struct { + TimeLimit time.Duration + RateLimit int + Logs map[string][]Action +} + +func (l *Limiter) LogInteraction(userid string, action string) { + l.Logs[userid] = append(l.Logs[userid], Action{ + Timestamp: time.Now(), + Type: action, + }) +} + +/* CheckAllowed counts the amount of log entries for a given userid, making sure to delete and not count the expired ones. +Returns an error if the amount of log entries exceeds the ratelimit */ +func (l *Limiter) CheckAllowed(userid string) error { + counter := 0 + expiredEntries := make([]int, 0) + for i := 0; i < len(l.Logs[userid]); i++ { + /* If the timestamp plus the timelimit is happened before "Now" */ + if l.Logs[userid][i].Timestamp.Add(l.TimeLimit).Before(time.Now()) { + expiredEntries = append(expiredEntries, i) + continue + } else { + counter++ + continue + } + } + /* remove entries */ + for i := 0; i < len(expiredEntries); i++ { + l.removeAction(userid, expiredEntries[i]) + } + + if counter >= l.RateLimit { + return errors.New("rate limit exceeded") + } else { + return nil + } +} + +func (l *Limiter) removeAction(userid string, i int) { + l.Logs[userid][i] = l.Logs[userid][len(l.Logs[userid])-1] + l.Logs[userid] = l.Logs[userid][:len(l.Logs[userid])-1] +}