summaryrefslogtreecommitdiff
path: root/middleware.go
blob: 33aeae0e751039edea31902a10e0b6fad356066e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
package main

import (
	"context"
	"encoding/base64"
	"fmt"
	"log"
	"net/http"
	"strings"
	"time"
)

// User describes the person behind an authenticated request, as assembled by
// GetUser from identity headers forwarded to this service.
//
// NOTE(review): if any code builds User with a positional composite literal,
// the field order below is significant — append new fields at the end or
// switch such literals to keyed form.
type User struct {
	UserId      string // full value of the remote-user header
	UserName    string // UserId up to (not including) the first "@"
	DisplayName string // given name and surname separated by a space
	Email       string // value of the "Mail" header
	Active      bool   // Staff, or has the "member@nordu.net" affiliation
	Staff       bool   // has the "employee@nordu.net" affiliation
}

// GetUser builds the *User for req from the identity headers forwarded by the
// authenticating front end.
//
// The user identifier comes from the pwman.RemoteUserHeader header, which must
// be sent exactly once with a non-blank value. Given name, surname, e-mail and
// affiliations come from the "Givenname", "Sn", "Mail" and "Affiliation"
// headers. The user is Staff when affiliated as "employee@nordu.net", and
// Active when Staff or affiliated as "member@nordu.net". DisplayName is the
// given name and surname separated by a space, with leading/trailing blanks
// removed so a missing part does not leave a stray space.
//
// It returns an error when no usable user header is present, or when the user
// header was sent more than once.
func GetUser(req *http.Request) (*User, error) {
	// Header is keyed by canonical names; canonicalize the configured name so
	// a differently-cased setting does not make the lookup silently miss.
	userHeader := req.Header[http.CanonicalHeaderKey(pwman.RemoteUserHeader)]
	if len(userHeader) > 1 {
		return nil, fmt.Errorf("Expected one user, but got multiple")
	}
	// A present-but-blank value (e.g. forwarded by a proxy for an anonymous
	// request) must not be treated as an authenticated user.
	if len(userHeader) == 0 || strings.TrimSpace(userHeader[0]) == "" {
		return nil, fmt.Errorf("No user found")
	}
	userID := userHeader[0]

	// NOTE(review): header values are used verbatim; if the proxy forwards
	// non-ASCII names in some transfer encoding they may need decoding — confirm.
	givenName := first(req.Header["Givenname"])
	surname := first(req.Header["Sn"])
	affiliations := req.Header["Affiliation"]
	isStaff := contains(affiliations, "employee@nordu.net")

	return &User{
		UserId:      userID,
		UserName:    strings.SplitN(userID, "@", 2)[0],
		DisplayName: strings.TrimSpace(fmt.Sprintf("%v %v", givenName, surname)),
		Email:       first(req.Header["Mail"]),
		Active:      isStaff || contains(affiliations, "member@nordu.net"),
		Staff:       isStaff,
	}, nil
}

// RemoteUser is middleware that identifies the caller of every request via
// GetUser. When GetUser fails, the error is logged and the client receives a
// 401 "Please log in" response; next is not invoked. Otherwise the *User is
// attached to the request context under the key "user" and next is called
// with the derived request.
//
// NOTE(review): the context key is a plain string, which staticcheck (SA1029)
// flags as collision-prone. It is kept as-is because handlers elsewhere
// presumably read the user with ctx.Value("user") — confirm before changing.
func RemoteUser(next http.Handler) http.Handler {
	authenticate := func(rw http.ResponseWriter, r *http.Request) {
		u, lookupErr := GetUser(r)
		if lookupErr == nil {
			withUser := context.WithValue(r.Context(), "user", u)
			next.ServeHTTP(rw, r.WithContext(withUser))
			return
		}

		log.Println("ERROR:", lookupErr)
		// TODO: consider redirecting to the login page with a "next" parameter.
		http.Error(rw, "Please log in", http.StatusUnauthorized)
	}
	return http.HandlerFunc(authenticate)
}

// FlashMessage is middleware that delivers one-shot "flash" messages carried
// in the "flashmsg" cookie written by SetFlashMessage.
//
// When the cookie holds URL-safe base64 of "<msg>;_;<class>", the message and
// its class are stored in the request context under "flash" and
// "flash_class", and an expiring "flashmsg" cookie is sent so the message is
// delivered once. A missing or empty class, or a payload without the ";_;"
// separator, yields class "info". A cookie that fails to decode is expired and
// the request continues without a flash. Requests without the cookie pass
// through untouched.
//
// NOTE(review): neither the expiring cookie here nor the cookie written by
// SetFlashMessage sets Path, so browsers scope each to the directory of its
// request URL. A flash set under one directory may be neither seen nor
// cleared under another. Consider setting Path "/" in both places together.
func FlashMessage(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		expire := &http.Cookie{Name: "flashmsg", MaxAge: -1, Expires: time.Unix(1, 0)}

		cookie, err := req.Cookie("flashmsg")
		if err != nil {
			// No flash pending.
			next.ServeHTTP(w, req)
			return
		}

		raw, err := base64.URLEncoding.DecodeString(cookie.Value)
		if err != nil {
			// Undecodable cookie: drop it so it is not re-sent on every request.
			http.SetCookie(w, expire)
			next.ServeHTTP(w, req)
			return
		}

		msg, class := string(raw), "info"
		// Split on the LAST separator: the class is appended after the message,
		// so a message that itself contains ";_;" must not be shown with the
		// raw separator and class embedded in it.
		if i := strings.LastIndex(msg, ";_;"); i >= 0 {
			if c := msg[i+len(";_;"):]; c != "" {
				class = c
			}
			msg = msg[:i]
		}

		ctx := context.WithValue(req.Context(), "flash", msg)
		ctx = context.WithValue(ctx, "flash_class", class)
		http.SetCookie(w, expire)
		next.ServeHTTP(w, req.WithContext(ctx))
	})
}

// SetFlashMessage stores a one-shot message for a later request handled by
// FlashMessage. The text "<msg>;_;<class>" is encoded with URL-safe base64 and
// written to the "flashmsg" cookie on w. class is presumably a CSS class used
// to style the message; FlashMessage falls back to "info" when it is empty.
func SetFlashMessage(w http.ResponseWriter, msg, class string) {
	payload := fmt.Sprintf("%s;_;%s", msg, class)
	cookie := http.Cookie{
		Name:  "flashmsg",
		Value: base64.URLEncoding.EncodeToString([]byte(payload)),
	}
	http.SetCookie(w, &cookie)
}