Base stuff from other project

This commit is contained in:
Amarpreet Minhas 2019-02-03 01:57:08 -05:00
parent b7132ce2dc
commit 6bb4b1246f
5 changed files with 177 additions and 1 deletions

View file

@ -1,3 +1,6 @@
# sudoscientist
Sudo Scientist blog
API_PORT=8080 DBHOST="postgres.localhost" DBPORT="5432" DBUSER="asara" DBPW="PW" DBNAME="sudoscientist" \
go run main.go

63
main.go Normal file
View file

@ -0,0 +1,63 @@
package main
import (
"fmt"
_ "github.com/lib/pq"
"log"
"net/http"
"os"
"git.minhas.io/asara/sudoscientist/packages/auth"
"git.minhas.io/asara/sudoscientist/packages/database"
"github.com/dgrijalva/jwt-go"
"github.com/go-chi/chi"
"github.com/go-chi/chi/middleware"
"github.com/go-chi/jwtauth"
"github.com/go-chi/render"
)
func main() {
// initiate the database
db, _ := database.NewDB()
defer db.Close()
auth.DB = db
auth.Init()
// initiate jwt token
auth.TokenAuth = jwtauth.New("HS256", []byte("secret"), nil)
_, tokenString, _ := auth.TokenAuth.Encode(jwt.MapClaims{"asara": 123})
log.Printf("DEBUG: a sample jwt is %s\n\n", tokenString)
// initiate the routes
router := Routes()
walkFunc := func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error {
fmt.Printf("%s %s\n", method, route)
return nil
}
if err := chi.Walk(router, walkFunc); err != nil {
log.Panicf("Logging err: %s\n", err.Error())
}
// start server
apiPort := fmt.Sprintf(":%s", os.Getenv("API_PORT"))
log.Fatal(http.ListenAndServe(apiPort, router))
}
func Routes() *chi.Mux {
router := chi.NewRouter()
router.Use(
render.SetContentType(render.ContentTypeJSON),
middleware.Logger,
middleware.DefaultCompress,
middleware.RedirectSlashes,
middleware.Recoverer,
)
router.Route("/v1", func(r chi.Router) {
r.Mount("/api/auth", auth.Routes())
})
return router
}

84
packages/auth/auth.go Normal file
View file

@ -0,0 +1,84 @@
package auth
import (
"fmt"
"database/sql"
"encoding/json"
"github.com/dgrijalva/jwt-go"
"github.com/go-chi/chi"
"github.com/go-chi/jwtauth"
"github.com/go-chi/render"
"golang.org/x/crypto/bcrypt"
"net/http"
)
var (
DB *sql.DB
TokenAuth *jwtauth.JWTAuth
)
type Credentials struct {
Password string `json:"password", db:"password"`
Username string `json:"username", db:"username"`
}
func Init() {
DB.Exec("CREATE TABLE IF NOT EXISTS users (username text primary key, password text, admin boolean);" )
}
func Routes() *chi.Mux {
router := chi.NewRouter()
router.Post("/signin", Signin)
router.Post("/signup", Signup)
return router
}
func Signup(w http.ResponseWriter, r *http.Request) {
creds := &Credentials{}
err := json.NewDecoder(r.Body).Decode(creds)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(creds.Password), 10)
s := `INSERT INTO users (username, password, admin)
VALUES ($1, $2, $3)`
if _, err = DB.Exec(s, creds.Username, string(hashedPassword), false); err != nil {
w.WriteHeader(http.StatusInternalServerError)
fmt.Println(err)
return
}
w.WriteHeader(http.StatusCreated)
}
func Signin(w http.ResponseWriter, r *http.Request) {
creds := &Credentials{}
err := json.NewDecoder(r.Body).Decode(creds)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
result := DB.QueryRow("SELECT password FROM users WHERE username=$1", creds.Username)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
storedCreds := &Credentials{}
err = result.Scan(&storedCreds.Password)
if err != nil {
if err == sql.ErrNoRows {
w.WriteHeader(http.StatusUnauthorized)
return
}
w.WriteHeader(http.StatusInternalServerError)
return
}
if err = bcrypt.CompareHashAndPassword([]byte(storedCreds.Password), []byte(creds.Password)); err != nil {
w.WriteHeader(http.StatusUnauthorized)
}
_, tokenString, _ := TokenAuth.Encode(jwt.MapClaims{
"username": creds.Username,
})
w.WriteHeader(http.StatusOK)
render.JSON(w, r, tokenString)
}

6
packages/auth/users.sql Normal file
View file

@ -0,0 +1,6 @@
create table users (
username text primary key,
password text,
admin boolean
);

View file

@ -0,0 +1,20 @@
package database
import (
"fmt"
"os"
"database/sql"
)
func NewDB() (*sql.DB, error) {
psqlInfo := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable", os.Getenv("DBHOST"), os.Getenv("DBPORT"), os.Getenv("DBUSER"), os.Getenv("DBPW"), os.Getenv("DBNAME"))
db, err := sql.Open("postgres", psqlInfo)
if err != nil {
panic(err)
}
err = db.Ping()
if err != nil {
panic(err)
}
fmt.Println("Connected to database")
return db, nil
}