Added rate limiter
This commit is contained in:
		
							parent
							
								
									394cb914b5
								
							
						
					
					
						commit
						2dc44348a4
					
				| @ -11,7 +11,15 @@ func (app *application) messageCreate(s *discordgo.Session, m *discordgo.Message | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	/* Checking if the message starts with the trigger specified in application struct  | ||||
| 	/* Check if the user is even allowed by the rate limiter */ | ||||
| 	err := app.limiter.CheckAllowed(m.Author.ID) | ||||
| 	if err != nil { | ||||
| 		/* normally don't send, but now do for debug purposes. This is the admin bot channel */ | ||||
| 		app.unknownError(err, s, true, "815952128106430514") | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	/* Checking if the message starts with the trigger specified in application struct | ||||
| 	if it does then we start the switch statement to trigger the appropriate function | ||||
| 	if it does not then we check if it contains a triggerword from the database */ | ||||
| 	if strings.HasPrefix(m.Content, app.trigger) { | ||||
| @ -52,14 +60,13 @@ func (app *application) messageCreate(s *discordgo.Session, m *discordgo.Message | ||||
| 			case "removeadmin": | ||||
| 				app.removeAdmin(s, m, splitCommand) | ||||
| 			} | ||||
| 			 | ||||
| 
 | ||||
| 		} | ||||
| 	} else { | ||||
| 		/* If the trigger wasn't the prefix of the message, we need to check all the words for a trigger */ | ||||
| 		app.findTrigger(s, m) | ||||
| 	} | ||||
| 
 | ||||
| 	 | ||||
| 	 | ||||
| } | ||||
| 	app.limiter.LogInteraction(m.Author.ID, "messagecreate") | ||||
| 
 | ||||
| } | ||||
|  | ||||
| @ -6,8 +6,10 @@ import ( | ||||
| 	"os" | ||||
| 	"os/signal" | ||||
| 	"syscall" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/bwmarrin/discordgo" | ||||
| 	"quenten.nl/pepebot/limiter" | ||||
| 	"quenten.nl/pepebot/models/mysql" | ||||
| ) | ||||
| 
 | ||||
| @ -15,15 +17,16 @@ import ( | ||||
| It also has many methods for the different functions of the bot. | ||||
| These methods are mostly located in discord.go */ | ||||
| type application struct { | ||||
| 	errorLog *log.Logger | ||||
| 	infoLog  *log.Logger | ||||
| 	badwords *mysql.BadwordModel | ||||
| 	adminroles *mysql.AdminRolesModel | ||||
| 	trigger string | ||||
| 	errorLog    *log.Logger | ||||
| 	infoLog     *log.Logger | ||||
| 	badwords    *mysql.BadwordModel | ||||
| 	adminroles  *mysql.AdminRolesModel | ||||
| 	trigger     string | ||||
| 	allBadWords map[string][]string | ||||
| 	limiter     *limiter.Limiter | ||||
| 
 | ||||
| 	active bool | ||||
| 	stop bool | ||||
| 	stop   bool | ||||
| } | ||||
| 
 | ||||
| func main() { | ||||
| @ -39,12 +42,18 @@ func main() { | ||||
| 		errorLog.Fatal(err) | ||||
| 	} | ||||
| 
 | ||||
| 	limiter := &limiter.Limiter{ | ||||
| 		RateLimit: 5, | ||||
| 		TimeLimit: time.Second * 15, | ||||
| 	} | ||||
| 
 | ||||
| 	app := &application{ | ||||
| 		infoLog: infoLog, | ||||
| 		errorLog: errorLog, | ||||
| 		badwords: &mysql.BadwordModel{DB: db}, | ||||
| 		infoLog:    infoLog, | ||||
| 		errorLog:   errorLog, | ||||
| 		badwords:   &mysql.BadwordModel{DB: db}, | ||||
| 		adminroles: &mysql.AdminRolesModel{DB: db}, | ||||
| 		trigger: "!pepe", | ||||
| 		trigger:    "!pepe", | ||||
| 		limiter:    limiter, | ||||
| 	} | ||||
| 
 | ||||
| 	app.allBadWords, err = app.badwords.AllWords() | ||||
|  | ||||
							
								
								
									
										8
									
								
								limiter/action.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								limiter/action.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,8 @@ | ||||
| package limiter | ||||
| 
 | ||||
| import "time" | ||||
| 
 | ||||
| type Action struct { | ||||
| 	Type      string | ||||
| 	Timestamp time.Time | ||||
| } | ||||
							
								
								
									
										55
									
								
								limiter/limiter.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								limiter/limiter.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,55 @@ | ||||
| package limiter | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| /* The Limiter struct saves all interactions in a map of lists indexed by user id going back the time limit. | ||||
| When checking if a user is allowed to perform an action, it traverses the list for that userid. | ||||
| If an item is older than the time limit, remove it and don't count. | ||||
| If it is in the limit, then count. If the amount of interactions is higher than the limit, return an error */ | ||||
| type Limiter struct { | ||||
| 	TimeLimit time.Duration | ||||
| 	RateLimit int | ||||
| 	Logs      map[string][]Action | ||||
| } | ||||
| 
 | ||||
| func (l *Limiter) LogInteraction(userid string, action string) { | ||||
| 	l.Logs[userid] = append(l.Logs[userid], Action{ | ||||
| 		Timestamp: time.Now(), | ||||
| 		Type:      action, | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| /* CheckAllowed counts the amount of log entries for a given userid, making sure to delete and not count the expired ones. | ||||
| Returns an error if the amount of log entries exceeds the ratelimit */ | ||||
| func (l *Limiter) CheckAllowed(userid string) error { | ||||
| 	counter := 0 | ||||
| 	expiredEntries := make([]int, 0) | ||||
| 	for i := 0; i < len(l.Logs[userid]); i++ { | ||||
| 		/* If the timestamp plus the timelimit is happened before "Now" */ | ||||
| 		if l.Logs[userid][i].Timestamp.Add(l.TimeLimit).Before(time.Now()) { | ||||
| 			expiredEntries = append(expiredEntries, i) | ||||
| 			continue | ||||
| 		} else { | ||||
| 			counter++ | ||||
| 			continue | ||||
| 		} | ||||
| 	} | ||||
| 	/* remove entries */ | ||||
| 	for i := 0; i < len(expiredEntries); i++ { | ||||
| 		l.removeAction(userid, expiredEntries[i]) | ||||
| 	} | ||||
| 
 | ||||
| 	if counter >= l.RateLimit { | ||||
| 		return errors.New("rate limit exceeded") | ||||
| 	} else { | ||||
| 		return nil | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (l *Limiter) removeAction(userid string, i int) { | ||||
| 	l.Logs[userid][i] = l.Logs[userid][len(l.Logs[userid])-1] | ||||
| 	l.Logs[userid] = l.Logs[userid][:len(l.Logs[userid])-1] | ||||
| } | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user