it works!
This commit is contained in:
@@ -0,0 +1,256 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
_ "embed"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"html/template"
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
const SessionCookie = "session"
|
||||
|
||||
type OAuthStore struct {
|
||||
sessions map[string]*Session
|
||||
mutex sync.RWMutex
|
||||
oa2 *oauth2.Config
|
||||
config *Config
|
||||
}
|
||||
|
||||
func NewSessionStore(config *Config, oa2 *oauth2.Config) *OAuthStore {
|
||||
return &OAuthStore{
|
||||
sessions: make(map[string]*Session),
|
||||
oa2: oa2,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OAuthStore) CreateSession(userID string, duration time.Duration) (*Session, error) {
|
||||
uid, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
id := uid.String()
|
||||
|
||||
session := &Session{
|
||||
ID: id,
|
||||
UserID: userID,
|
||||
ExpiresAt: time.Now().Add(duration),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
s.sessions[id] = session
|
||||
s.mutex.Unlock()
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (s *OAuthStore) GetSession(sessionID string) (*Session, bool) {
|
||||
s.mutex.RLock()
|
||||
defer s.mutex.RUnlock()
|
||||
|
||||
session, exists := s.sessions[sessionID]
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
if session.ExpiresAt.Before(time.Now()) {
|
||||
s.DeleteSession(sessionID)
|
||||
return nil, false
|
||||
}
|
||||
return session, true
|
||||
|
||||
}
|
||||
|
||||
func (s *OAuthStore) DeleteSession(sessionID string) {
|
||||
s.mutex.Lock()
|
||||
delete(s.sessions, sessionID)
|
||||
s.mutex.Unlock()
|
||||
}
|
||||
|
||||
func sendToLoginPage(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "/oauth/login", http.StatusTemporaryRedirect)
|
||||
}
|
||||
func generateRandomToken() string {
|
||||
b := make([]byte, 32)
|
||||
rand.Read(b)
|
||||
return base64.StdEncoding.EncodeToString(b)
|
||||
}
|
||||
|
||||
//go:embed templates/LoginPage.html
|
||||
var loginPageContent string
|
||||
|
||||
func (s *OAuthStore) LoginPage() http.Handler {
|
||||
|
||||
loginPageTemplate := template.Must(template.New("loginPageContent").Parse(loginPageContent))
|
||||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
state := generateRandomToken()
|
||||
|
||||
if cookie, err := r.Cookie("oauth_state"); err == nil && cookie.Value != "" {
|
||||
state = cookie.Value
|
||||
} else {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "oauth_state",
|
||||
Value: state,
|
||||
HttpOnly: true,
|
||||
Secure: true, // Set to true in production
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: 60 * 10, // 10 minutes
|
||||
})
|
||||
}
|
||||
|
||||
url := s.oa2.AuthCodeURL(state)
|
||||
|
||||
provider := s.config.OAuthProvider.Kind
|
||||
switch provider {
|
||||
case "google":
|
||||
provider = "Google"
|
||||
case "github":
|
||||
provider = "GitHub"
|
||||
}
|
||||
|
||||
loginPageTemplate.Execute(w, struct {
|
||||
Url string
|
||||
State string
|
||||
Provider string
|
||||
}{
|
||||
Url: url,
|
||||
State: state,
|
||||
Provider: provider,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (s *OAuthStore) Protected(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
cookie, err := r.Cookie(SessionCookie)
|
||||
if err != nil {
|
||||
sendToLoginPage(w, r)
|
||||
return
|
||||
}
|
||||
sess, exists := s.GetSession(cookie.Value)
|
||||
if !exists {
|
||||
sendToLoginPage(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, user := range s.config.AllowedUsers {
|
||||
if user == sess.UserID {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
sendToLoginPage(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *OAuthStore) CallbackHandler() http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
stateCookie, err := r.Cookie("oauth_state")
|
||||
if err != nil || stateCookie.Value != r.URL.Query().Get("state") {
|
||||
http.Error(w, "Invalid state", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
tok, err := s.oa2.Exchange(r.Context(), r.URL.Query().Get("code"))
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to exchange token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
userID, err := getUserInfo(s.config.OAuthProvider.Kind, tok.AccessToken)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to get info", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
println(userID)
|
||||
sess, err := s.CreateSession(userID, time.Hour*24)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to create session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: SessionCookie,
|
||||
Value: sess.ID,
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: int(time.Hour.Seconds() * 24),
|
||||
Path: "/",
|
||||
})
|
||||
|
||||
// clear cookie
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "oauth_state",
|
||||
Value: "",
|
||||
MaxAge: -1,
|
||||
})
|
||||
|
||||
// TODO: remember what path the user was on and redirect them back there after doing the whole login process
|
||||
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
||||
})
|
||||
}
|
||||
|
||||
func getUserInfo(providerKind, token string) (string, error) {
|
||||
switch providerKind {
|
||||
case "google":
|
||||
type UserInfo struct {
|
||||
ID string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
resp, err := http.Get("https://www.googleapis.com/oauth2/v2/userinfo?access_token=" + token)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var userInfo UserInfo
|
||||
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return userInfo.Email, nil
|
||||
case "github":
|
||||
type UserInfo struct {
|
||||
Login string `json:"login"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
req, _ := http.NewRequest("GET", "https://api.github.com/user", nil)
|
||||
req.Header.Add("Authorization", "Bearer "+token)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var userInfo UserInfo
|
||||
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return userInfo.Login, nil
|
||||
default:
|
||||
panic("unimplemented")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user