Added rate limiter

This commit is contained in:
DutchEllie 2021-05-31 10:58:16 +02:00
parent 394cb914b5
commit 2dc44348a4
4 changed files with 94 additions and 15 deletions

View File

@ -11,7 +11,15 @@ func (app *application) messageCreate(s *discordgo.Session, m *discordgo.Message
return
}
/* Checking if the message starts with the trigger specified in application struct
/* Check if the user is even allowed by the rate limiter */
err := app.limiter.CheckAllowed(m.Author.ID)
if err != nil {
/* normally don't send, but now do for debug purposes. This is the admin bot channel */
app.unknownError(err, s, true, "815952128106430514")
return
}
/* Checking if the message starts with the trigger specified in application struct
if it does then we start the switch statement to trigger the appropriate function
if it does not then we check if it contains a triggerword from the database */
if strings.HasPrefix(m.Content, app.trigger) {
@ -52,14 +60,13 @@ func (app *application) messageCreate(s *discordgo.Session, m *discordgo.Message
case "removeadmin":
app.removeAdmin(s, m, splitCommand)
}
}
} else {
/* If the trigger wasn't the prefix of the message, we need to check all the words for a trigger */
app.findTrigger(s, m)
}
}
app.limiter.LogInteraction(m.Author.ID, "messagecreate")
}

View File

@ -6,8 +6,10 @@ import (
"os"
"os/signal"
"syscall"
"time"
"github.com/bwmarrin/discordgo"
"quenten.nl/pepebot/limiter"
"quenten.nl/pepebot/models/mysql"
)
@ -15,15 +17,16 @@ import (
It also has many methods for the different functions of the bot.
These methods are mostly located in discord.go */
type application struct {
errorLog *log.Logger
infoLog *log.Logger
badwords *mysql.BadwordModel
adminroles *mysql.AdminRolesModel
trigger string
errorLog *log.Logger
infoLog *log.Logger
badwords *mysql.BadwordModel
adminroles *mysql.AdminRolesModel
trigger string
allBadWords map[string][]string
limiter *limiter.Limiter
active bool
stop bool
stop bool
}
func main() {
@ -39,12 +42,18 @@ func main() {
errorLog.Fatal(err)
}
limiter := &limiter.Limiter{
RateLimit: 5,
TimeLimit: time.Second * 15,
}
app := &application{
infoLog: infoLog,
errorLog: errorLog,
badwords: &mysql.BadwordModel{DB: db},
infoLog: infoLog,
errorLog: errorLog,
badwords: &mysql.BadwordModel{DB: db},
adminroles: &mysql.AdminRolesModel{DB: db},
trigger: "!pepe",
trigger: "!pepe",
limiter: limiter,
}
app.allBadWords, err = app.badwords.AllWords()

8
limiter/action.go Normal file
View File

@ -0,0 +1,8 @@
package limiter
import "time"
type Action struct {
Type string
Timestamp time.Time
}

55
limiter/limiter.go Normal file
View File

@ -0,0 +1,55 @@
package limiter
import (
"errors"
"time"
)
/* The Limiter struct saves all interactions in a map of lists indexed by user id going back the time limit.
When checking if a user is allowed to perform an action, it traverses the list for that userid.
If an item is older than the time limit, remove it and don't count.
If it is in the limit, then count. If the amount of interactions is higher than the limit, return an error */
type Limiter struct {
TimeLimit time.Duration
RateLimit int
Logs map[string][]Action
}
func (l *Limiter) LogInteraction(userid string, action string) {
l.Logs[userid] = append(l.Logs[userid], Action{
Timestamp: time.Now(),
Type: action,
})
}
/* CheckAllowed counts the amount of log entries for a given userid, making sure to delete and not count the expired ones.
Returns an error if the amount of log entries exceeds the ratelimit */
func (l *Limiter) CheckAllowed(userid string) error {
counter := 0
expiredEntries := make([]int, 0)
for i := 0; i < len(l.Logs[userid]); i++ {
/* If the timestamp plus the timelimit is happened before "Now" */
if l.Logs[userid][i].Timestamp.Add(l.TimeLimit).Before(time.Now()) {
expiredEntries = append(expiredEntries, i)
continue
} else {
counter++
continue
}
}
/* remove entries */
for i := 0; i < len(expiredEntries); i++ {
l.removeAction(userid, expiredEntries[i])
}
if counter >= l.RateLimit {
return errors.New("rate limit exceeded")
} else {
return nil
}
}
func (l *Limiter) removeAction(userid string, i int) {
l.Logs[userid][i] = l.Logs[userid][len(l.Logs[userid])-1]
l.Logs[userid] = l.Logs[userid][:len(l.Logs[userid])-1]
}