diff --git a/packages/auth/auth.go b/packages/auth/auth.go index e832a61..6955f08 100644 --- a/packages/auth/auth.go +++ b/packages/auth/auth.go @@ -56,13 +56,15 @@ type SignUpCredentials struct { Password string `json:"password", db:"password"` } -type SignInCredentials struct { +type UserCredentials struct { Username string `json:"username", db:"username"` Password string `json:"password", db:"password"` } type Claims struct { - Username string `json:"username"` + Username string `json:"username", db:"username"` + Admin string `json:"admin", db:"admin"` + Verified string `json:"verified", db:"verified"` jwt.StandardClaims } @@ -165,29 +167,25 @@ func register(w http.ResponseWriter, r *http.Request) { } func signin(w http.ResponseWriter, r *http.Request) { - creds := &SignInCredentials{} + creds := &UserCredentials{} err := json.NewDecoder(r.Body).Decode(creds) if err != nil { w.WriteHeader(http.StatusBadRequest) return } - result := DB.QueryRow("SELECT password FROM users WHERE username=$1", creds.Username) - storedCreds := &SignInCredentials{} - err = result.Scan(&storedCreds.Password) - if err != nil { - if err == sql.ErrNoRows { - w.WriteHeader(http.StatusUnauthorized) - return - } - w.WriteHeader(http.StatusInternalServerError) + verified, ok := checkPassword(w, *creds) + if !ok { + render.JSON(w, r, verified) return } - if err = bcrypt.CompareHashAndPassword([]byte(storedCreds.Password), []byte(creds.Password)); err != nil { - w.WriteHeader(http.StatusUnauthorized) - } expirationTime := time.Now().Add(24 * time.Hour) + user_claims := &Claims{} + user_claims_query := DB.QueryRow("SELECT username, admin, verified FROM users WHERE username=$1", creds.Username) + err = user_claims_query.Scan(&user_claims.Username, &user_claims.Admin, &user_claims.Verified) claims := &Claims{ - Username: creds.Username, + Username: user_claims.Username, + Admin: user_claims.Admin, + Verified: user_claims.Verified, StandardClaims: jwt.StandardClaims{ ExpiresAt: expirationTime.Unix(), }, @@ -199,11 +197,24 @@ func signin(w http.ResponseWriter, r *http.Request) { } func refresh(w http.ResponseWriter, r *http.Request) { + returnMessage := ReturnMessage{} _, claims, _ := jwtauth.FromContext(r.Context()) w.WriteHeader(http.StatusOK) expirationTime := time.Now().Add(5 * time.Hour) + user_claims := &Claims{} + user_claims_query := DB.QueryRow("SELECT username, admin, verified FROM users WHERE username=$1", claims["username"].(string)) + err := user_claims_query.Scan(&user_claims.Username, &user_claims.Admin, &user_claims.Verified) + if err != nil { + fmt.Println(err) + returnMessage.Message = "unexpected error refreshing your token, please try again later" + w.WriteHeader(http.StatusInternalServerError) + render.JSON(w, r, returnMessage) + return + } newClaims := &Claims{ - Username: claims["username"].(string), + Username: user_claims.Username, + Admin: user_claims.Admin, + Verified: user_claims.Verified, StandardClaims: jwt.StandardClaims{ ExpiresAt: expirationTime.Unix(), }, @@ -263,3 +274,27 @@ func sendEmailToken(w http.ResponseWriter, token string, name string, email stri ok = true return returnMessage, ok } + +func checkPassword(w http.ResponseWriter, creds UserCredentials) (string, bool) { + returnMessage := "" + result := DB.QueryRow("SELECT password FROM users WHERE username=$1", creds.Username) + storedCreds := &UserCredentials{} + err := result.Scan(&storedCreds.Password) + if err != nil { + if err == sql.ErrNoRows { + w.WriteHeader(http.StatusUnauthorized) + returnMessage = "username or password incorrect" + return returnMessage, false + } + w.WriteHeader(http.StatusInternalServerError) + returnMessage = "something is broken, please contact the administrator" + return returnMessage, false + } + if err = bcrypt.CompareHashAndPassword([]byte(storedCreds.Password), []byte(creds.Password)); err != nil { + w.WriteHeader(http.StatusUnauthorized) + returnMessage = "username or password incorrect" + return returnMessage, false + } + returnMessage = "user logged in" + return returnMessage, true +}