diff --git a/db/cache/fastcache.go b/db/cache/fastcache.go index 9dff1b7..b6c8fb2 100644 --- a/db/cache/fastcache.go +++ b/db/cache/fastcache.go @@ -43,11 +43,13 @@ func (f *FastCache) Get(k []byte) ([]byte, error) { if b == nil { return nil, nil } - me := ttlCache{} err := binary.Unmarshal(b, &me) if err != nil { return nil, fmt.Errorf("FastCache.Get: %w", err) } + if time.Unix(me.TimeOut, 0).Before(time.Now()) { + return nil, nil + } return me.V, nil } diff --git a/service/yggdrasil/authenticate.go b/service/yggdrasil/authenticate.go index e565bc8..a669afc 100644 --- a/service/yggdrasil/authenticate.go +++ b/service/yggdrasil/authenticate.go @@ -25,7 +25,7 @@ var ( ) func (y *Yggdrasil) Authenticate(cxt context.Context, auth yggdrasil.Authenticate) (yggdrasil.Token, error) { - err := rate("Authenticate"+auth.Username, y.cache, 5*time.Second) + err := rate("Authenticate"+auth.Username, y.cache, 10*time.Second, 3) if err != nil { return yggdrasil.Token{}, fmt.Errorf("Authenticate: %w", err) } @@ -103,22 +103,34 @@ func (y *Yggdrasil) Authenticate(cxt context.Context, auth yggdrasil.Authenticat }, nil } -func rate(k string, c cache.Cache, d time.Duration) error { +func rate(k string, c cache.Cache, d time.Duration, count uint) error { key := []byte(k) v, err := c.Get([]byte(key)) if err != nil { return fmt.Errorf("rate: %w", err) } - if v != nil { - u := binary.BigEndian.Uint64(v) - t := time.Unix(int64(u), 0) - if time.Now().Before(t) { - return fmt.Errorf("rate: %w", ErrRate) + if v == nil { + err := putUint(1, c, key, d) + if err != nil { + return fmt.Errorf("rate: %w", err) } + return nil } - b := make([]byte, 8) - binary.BigEndian.PutUint64(b, uint64(time.Now().Add(d).Unix())) - err = c.Put(key, b, time.Now().Add(d)) + n := binary.BigEndian.Uint64(v) + if n > uint64(count) { + return fmt.Errorf("rate: %w", ErrRate) + } + err = putUint(n+1, c, key, d) + if err != nil { + return fmt.Errorf("rate: %w", err) + } + return nil +} + +func putUint(n uint64, c cache.Cache, key []byte, d time.Duration) error { + b := make([]byte, 8) + binary.BigEndian.PutUint64(b, n) + err := c.Put(key, b, time.Now().Add(d)) if err != nil { return fmt.Errorf("rate: %w", err) }