Compare commits
2 Commits
e2d8bd344d
...
important-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7b1293bb6f | ||
|
|
d9de02f08d |
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"ResendIt/internal/api/middleware"
|
||||
"ResendIt/internal/auth"
|
||||
"ResendIt/internal/db"
|
||||
"ResendIt/internal/file"
|
||||
@@ -14,6 +15,7 @@ import (
|
||||
"html/template"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/joho/godotenv"
|
||||
@@ -74,6 +76,9 @@ func main() {
|
||||
createAdminUser(userService)
|
||||
|
||||
apiRoute := r.Group("/api")
|
||||
// General API rate limiting to reduce abuse/spam.
|
||||
// ~60 req/min per IP with some burst room.
|
||||
apiRoute.Use(middleware.RateLimitByIP(60, time.Minute, 30, 5*time.Minute))
|
||||
|
||||
auth.RegisterRoutes(apiRoute, authHandler)
|
||||
user.RegisterRoutes(apiRoute, userHandler)
|
||||
|
||||
118
internal/api/middleware/ratelimit.go
Normal file
118
internal/api/middleware/ratelimit.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type tokenBucket struct {
|
||||
mu sync.Mutex
|
||||
rate float64 // tokens per second
|
||||
burst float64 // max tokens
|
||||
tokens float64
|
||||
last time.Time
|
||||
}
|
||||
|
||||
func newTokenBucket(max int, per time.Duration, burst int) *tokenBucket {
|
||||
if burst <= 0 {
|
||||
burst = max
|
||||
}
|
||||
rate := float64(max) / per.Seconds()
|
||||
b := float64(burst)
|
||||
now := time.Now()
|
||||
return &tokenBucket{rate: rate, burst: b, tokens: b, last: now}
|
||||
}
|
||||
|
||||
func (b *tokenBucket) allow() bool {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
delta := now.Sub(b.last).Seconds()
|
||||
b.last = now
|
||||
|
||||
b.tokens += delta * b.rate
|
||||
if b.tokens > b.burst {
|
||||
b.tokens = b.burst
|
||||
}
|
||||
|
||||
if b.tokens < 1 {
|
||||
return false
|
||||
}
|
||||
b.tokens -= 1
|
||||
return true
|
||||
}
|
||||
|
||||
type ipClient struct {
|
||||
bucket *tokenBucket
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
// RateLimitByIP returns a Gin middleware that rate limits requests per client IP.
|
||||
//
|
||||
// max: max requests per time window (per)
|
||||
// per: the time window duration
|
||||
// burst: optional burst capacity (defaults to max if <=0)
|
||||
// ttl: how long to keep idle IP buckets around
|
||||
func RateLimitByIP(max int, per time.Duration, burst int, ttl time.Duration) gin.HandlerFunc {
|
||||
var (
|
||||
mu sync.Mutex
|
||||
clients = make(map[string]*ipClient)
|
||||
)
|
||||
|
||||
// opportunistic cleanup (runs at most once per minute)
|
||||
var (
|
||||
cleanupMu sync.Mutex
|
||||
lastCleanup time.Time
|
||||
)
|
||||
|
||||
cleanup := func(now time.Time) {
|
||||
cleanupMu.Lock()
|
||||
defer cleanupMu.Unlock()
|
||||
if !lastCleanup.IsZero() && now.Sub(lastCleanup) < time.Minute {
|
||||
return
|
||||
}
|
||||
lastCleanup = now
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
for ip, c := range clients {
|
||||
if now.Sub(c.lastSeen) > ttl {
|
||||
delete(clients, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
getClient := func(ip string, now time.Time) *ipClient {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
c, ok := clients[ip]
|
||||
if !ok {
|
||||
c = &ipClient{bucket: newTokenBucket(max, per, burst), lastSeen: now}
|
||||
clients[ip] = c
|
||||
return c
|
||||
}
|
||||
c.lastSeen = now
|
||||
return c
|
||||
}
|
||||
|
||||
return func(c *gin.Context) {
|
||||
now := time.Now()
|
||||
cleanup(now)
|
||||
|
||||
ip := c.ClientIP()
|
||||
client := getClient(ip, now)
|
||||
|
||||
if !client.bucket.allow() {
|
||||
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "rate limit exceeded",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
64
internal/api/middleware/ratelimit_test.go
Normal file
64
internal/api/middleware/ratelimit_test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestRateLimitByIP_BlocksAfterLimit(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
// 1 request per hour, burst 1 => second immediate request should 429.
|
||||
r.Use(RateLimitByIP(1, time.Hour, 1, time.Minute))
|
||||
r.GET("/", func(c *gin.Context) { c.String(200, "ok") })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "203.0.113.10:1234"
|
||||
|
||||
w1 := httptest.NewRecorder()
|
||||
r.ServeHTTP(w1, req)
|
||||
if w1.Code != http.StatusOK {
|
||||
t.Fatalf("first request code = %d, want %d", w1.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
r.ServeHTTP(w2, req)
|
||||
if w2.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("second request code = %d, want %d", w2.Code, http.StatusTooManyRequests)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitByIP_AllowsBurst(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
// 1 per hour, but burst 2 => first two immediate requests should pass.
|
||||
r.Use(RateLimitByIP(1, time.Hour, 2, time.Minute))
|
||||
r.GET("/", func(c *gin.Context) { c.String(200, "ok") })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "203.0.113.11:1234"
|
||||
|
||||
w1 := httptest.NewRecorder()
|
||||
r.ServeHTTP(w1, req)
|
||||
if w1.Code != http.StatusOK {
|
||||
t.Fatalf("first request code = %d, want %d", w1.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
r.ServeHTTP(w2, req)
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Fatalf("second request code = %d, want %d", w2.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
w3 := httptest.NewRecorder()
|
||||
r.ServeHTTP(w3, req)
|
||||
if w3.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("third request code = %d, want %d", w3.Code, http.StatusTooManyRequests)
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package auth
|
||||
|
||||
import (
|
||||
"ResendIt/internal/api/middleware"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -9,7 +10,9 @@ import (
|
||||
func RegisterRoutes(r *gin.RouterGroup, h *Handler) {
|
||||
auth := r.Group("/auth")
|
||||
|
||||
auth.POST("/login", h.Login)
|
||||
// Stricter rate limit on login to reduce brute-force / log spam.
|
||||
// 5 attempts per minute per IP, burst 10.
|
||||
auth.POST("/login", middleware.RateLimitByIP(5, time.Minute, 10, 15*time.Minute), h.Login)
|
||||
|
||||
protected := auth.Group("/")
|
||||
protected.Use(middleware.AuthMiddleware())
|
||||
|
||||
82
internal/auth/service_test.go
Normal file
82
internal/auth/service_test.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"ResendIt/internal/security"
|
||||
"ResendIt/internal/user"
|
||||
"testing"
|
||||
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func TestServiceLogin_InvalidUserDoesNotEnumerate(t *testing.T) {
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("open sqlite: %v", err)
|
||||
}
|
||||
if err := db.AutoMigrate(&user.User{}); err != nil {
|
||||
t.Fatalf("migrate: %v", err)
|
||||
}
|
||||
|
||||
svc := NewService(NewRepository(db))
|
||||
|
||||
_, err = svc.Login("does-not-exist", "whatever")
|
||||
if err != ErrInvalidCredentials {
|
||||
t.Fatalf("expected ErrInvalidCredentials for missing user, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServiceLogin_WrongPassword(t *testing.T) {
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("open sqlite: %v", err)
|
||||
}
|
||||
if err := db.AutoMigrate(&user.User{}); err != nil {
|
||||
t.Fatalf("migrate: %v", err)
|
||||
}
|
||||
|
||||
hash, err := security.HashPassword("right")
|
||||
if err != nil {
|
||||
t.Fatalf("hash: %v", err)
|
||||
}
|
||||
|
||||
u := user.User{Username: "alice", PasswordHash: hash, Role: "user"}
|
||||
if err := db.Create(&u).Error; err != nil {
|
||||
t.Fatalf("create user: %v", err)
|
||||
}
|
||||
|
||||
svc := NewService(NewRepository(db))
|
||||
_, err = svc.Login("alice", "wrong")
|
||||
if err != ErrInvalidCredentials {
|
||||
t.Fatalf("expected ErrInvalidCredentials for wrong password, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServiceLogin_SuccessReturnsJWT(t *testing.T) {
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("open sqlite: %v", err)
|
||||
}
|
||||
if err := db.AutoMigrate(&user.User{}); err != nil {
|
||||
t.Fatalf("migrate: %v", err)
|
||||
}
|
||||
|
||||
hash, err := security.HashPassword("right")
|
||||
if err != nil {
|
||||
t.Fatalf("hash: %v", err)
|
||||
}
|
||||
|
||||
u := user.User{Username: "alice", PasswordHash: hash, Role: "user"}
|
||||
if err := db.Create(&u).Error; err != nil {
|
||||
t.Fatalf("create user: %v", err)
|
||||
}
|
||||
|
||||
svc := NewService(NewRepository(db))
|
||||
token, err := svc.Login("alice", "right")
|
||||
if err != nil {
|
||||
t.Fatalf("expected success, got error: %v", err)
|
||||
}
|
||||
if token == "" {
|
||||
t.Fatalf("expected non-empty jwt token")
|
||||
}
|
||||
}
|
||||
22
internal/security/password_test.go
Normal file
22
internal/security/password_test.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package security
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestHashAndCheckPassword(t *testing.T) {
|
||||
pw := "correct horse battery staple"
|
||||
|
||||
hash, err := HashPassword(pw)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword returned error: %v", err)
|
||||
}
|
||||
if hash == "" {
|
||||
t.Fatalf("expected non-empty hash")
|
||||
}
|
||||
|
||||
if !CheckPassword(pw, hash) {
|
||||
t.Fatalf("expected CheckPassword to succeed for correct password")
|
||||
}
|
||||
if CheckPassword("wrong", hash) {
|
||||
t.Fatalf("expected CheckPassword to fail for wrong password")
|
||||
}
|
||||
}
|
||||
40
internal/util/util_test.go
Normal file
40
internal/util/util_test.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package util
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestHumanSize(t *testing.T) {
|
||||
tests := []struct {
|
||||
in int64
|
||||
want string
|
||||
}{
|
||||
{0, "0 B"},
|
||||
{1, "1 B"},
|
||||
{1023, "1023 B"},
|
||||
{1024, "1.0 KB"},
|
||||
{1024 * 1024, "1.0 MB"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if got := HumanSize(tt.in); got != tt.want {
|
||||
t.Fatalf("HumanSize(%d) = %q, want %q", tt.in, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeFilename(t *testing.T) {
|
||||
if got := SafeFilename(" hello.txt "); got != "hello.txt" {
|
||||
t.Fatalf("expected trimmed filename, got %q", got)
|
||||
}
|
||||
|
||||
// Strips control characters, quotes, and backslashes.
|
||||
in := "a\n\rb\t\"c\\d"
|
||||
got := SafeFilename(in)
|
||||
if got != "abcd" {
|
||||
t.Fatalf("SafeFilename(%q) = %q, want %q", in, got, "abcd")
|
||||
}
|
||||
|
||||
// Empty after sanitization becomes default.
|
||||
if got := SafeFilename("\n\r\t"); got != "file" {
|
||||
t.Fatalf("expected default filename, got %q", got)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user