Added rate limiter
This commit is contained in:
parent
394cb914b5
commit
2dc44348a4
|
@ -11,7 +11,15 @@ func (app *application) messageCreate(s *discordgo.Session, m *discordgo.Message
|
|||
return
|
||||
}
|
||||
|
||||
/* Checking if the message starts with the trigger specified in application struct
|
||||
/* Check if the user is even allowed by the rate limiter */
|
||||
err := app.limiter.CheckAllowed(m.Author.ID)
|
||||
if err != nil {
|
||||
/* normally don't send, but now do for debug purposes. This is the admin bot channel */
|
||||
app.unknownError(err, s, true, "815952128106430514")
|
||||
return
|
||||
}
|
||||
|
||||
/* Checking if the message starts with the trigger specified in application struct
|
||||
if it does then we start the switch statement to trigger the appropriate function
|
||||
if it does not then we check if it contains a triggerword from the database */
|
||||
if strings.HasPrefix(m.Content, app.trigger) {
|
||||
|
@ -52,14 +60,13 @@ func (app *application) messageCreate(s *discordgo.Session, m *discordgo.Message
|
|||
case "removeadmin":
|
||||
app.removeAdmin(s, m, splitCommand)
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
} else {
|
||||
/* If the trigger wasn't the prefix of the message, we need to check all the words for a trigger */
|
||||
app.findTrigger(s, m)
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
app.limiter.LogInteraction(m.Author.ID, "messagecreate")
|
||||
|
||||
}
|
||||
|
|
|
@ -6,8 +6,10 @@ import (
|
|||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"quenten.nl/pepebot/limiter"
|
||||
"quenten.nl/pepebot/models/mysql"
|
||||
)
|
||||
|
||||
|
@ -15,15 +17,16 @@ import (
|
|||
It also has many methods for the different functions of the bot.
|
||||
These methods are mostly located in discord.go */
|
||||
type application struct {
|
||||
errorLog *log.Logger
|
||||
infoLog *log.Logger
|
||||
badwords *mysql.BadwordModel
|
||||
adminroles *mysql.AdminRolesModel
|
||||
trigger string
|
||||
errorLog *log.Logger
|
||||
infoLog *log.Logger
|
||||
badwords *mysql.BadwordModel
|
||||
adminroles *mysql.AdminRolesModel
|
||||
trigger string
|
||||
allBadWords map[string][]string
|
||||
limiter *limiter.Limiter
|
||||
|
||||
active bool
|
||||
stop bool
|
||||
stop bool
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
@ -39,12 +42,18 @@ func main() {
|
|||
errorLog.Fatal(err)
|
||||
}
|
||||
|
||||
limiter := &limiter.Limiter{
|
||||
RateLimit: 5,
|
||||
TimeLimit: time.Second * 15,
|
||||
}
|
||||
|
||||
app := &application{
|
||||
infoLog: infoLog,
|
||||
errorLog: errorLog,
|
||||
badwords: &mysql.BadwordModel{DB: db},
|
||||
infoLog: infoLog,
|
||||
errorLog: errorLog,
|
||||
badwords: &mysql.BadwordModel{DB: db},
|
||||
adminroles: &mysql.AdminRolesModel{DB: db},
|
||||
trigger: "!pepe",
|
||||
trigger: "!pepe",
|
||||
limiter: limiter,
|
||||
}
|
||||
|
||||
app.allBadWords, err = app.badwords.AllWords()
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
package limiter
|
||||
|
||||
import "time"
|
||||
|
||||
type Action struct {
|
||||
Type string
|
||||
Timestamp time.Time
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
package limiter
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
/* The Limiter struct saves all interactions in a map of lists indexed by user id going back the time limit.
|
||||
When checking if a user is allowed to perform an action, it traverses the list for that userid.
|
||||
If an item is older than the time limit, remove it and don't count.
|
||||
If it is in the limit, then count. If the amount of interactions is higher than the limit, return an error */
|
||||
type Limiter struct {
|
||||
TimeLimit time.Duration
|
||||
RateLimit int
|
||||
Logs map[string][]Action
|
||||
}
|
||||
|
||||
func (l *Limiter) LogInteraction(userid string, action string) {
|
||||
l.Logs[userid] = append(l.Logs[userid], Action{
|
||||
Timestamp: time.Now(),
|
||||
Type: action,
|
||||
})
|
||||
}
|
||||
|
||||
/* CheckAllowed counts the amount of log entries for a given userid, making sure to delete and not count the expired ones.
|
||||
Returns an error if the amount of log entries exceeds the ratelimit */
|
||||
func (l *Limiter) CheckAllowed(userid string) error {
|
||||
counter := 0
|
||||
expiredEntries := make([]int, 0)
|
||||
for i := 0; i < len(l.Logs[userid]); i++ {
|
||||
/* If the timestamp plus the timelimit is happened before "Now" */
|
||||
if l.Logs[userid][i].Timestamp.Add(l.TimeLimit).Before(time.Now()) {
|
||||
expiredEntries = append(expiredEntries, i)
|
||||
continue
|
||||
} else {
|
||||
counter++
|
||||
continue
|
||||
}
|
||||
}
|
||||
/* remove entries */
|
||||
for i := 0; i < len(expiredEntries); i++ {
|
||||
l.removeAction(userid, expiredEntries[i])
|
||||
}
|
||||
|
||||
if counter >= l.RateLimit {
|
||||
return errors.New("rate limit exceeded")
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Limiter) removeAction(userid string, i int) {
|
||||
l.Logs[userid][i] = l.Logs[userid][len(l.Logs[userid])-1]
|
||||
l.Logs[userid] = l.Logs[userid][:len(l.Logs[userid])-1]
|
||||
}
|
Loading…
Reference in New Issue