diff --git a/database/database.go b/database/database.go index 14616fa..e17bbe5 100644 --- a/database/database.go +++ b/database/database.go @@ -1,15 +1,20 @@ package database import ( + "crypto/md5" + "encoding/hex" "errors" "fmt" "github.com/google/uuid" "gorm.io/driver/mysql" "gorm.io/gorm" "gorm.io/gorm/logger" + "io" "log" + "mime/multipart" "opensavecloudserver/config" "os" + "path" "sync" "time" ) @@ -115,7 +120,7 @@ func CreateGame(userId int, name string) (*Game, error) { game := &Game{ Name: name, Revision: 0, - PathStorage: gameUUID.String(), + PathStorage: gameUUID.String() + ".bin", UserId: userId, Available: false, } @@ -135,14 +140,74 @@ func AskForUpload(userId, gameId int) (*GameUploadToken, error) { } if _, ok := locks[gameId]; !ok { token := uuid.New() - return &GameUploadToken{ + lock := GameUploadToken{ GameId: gameId, UploadToken: token.String(), - }, nil + } + locks[gameId] = lock + return &lock, nil } return nil, errors.New("game already locked") } +func CheckUploadToken(uploadToken string) (int, bool) { + mu.Lock() + defer mu.Unlock() + for _, lock := range locks { + if lock.UploadToken == uploadToken { + return lock.GameId, true + } + } + return -1, false +} + +func UploadSave(file multipart.File, game *Game) error { + filePath := path.Join(config.Path().Storage, game.PathStorage) + f, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE, 0666) + if err != nil { + return err + } + defer f.Close() + _, err = io.Copy(f, file) + if err != nil { + return err + } + return nil +} + +func UpdateGameRevision(game *Game) error { + filePath := path.Join(config.Path().Storage, game.PathStorage) + file, err := os.Open(filePath) + if err != nil { + return err + } + defer file.Close() + + hash := md5.New() + _, err = io.Copy(hash, file) + if err != nil { + return err + } + sum := hash.Sum(nil) + game.Revision += 1 + if game.Hash == nil { + game.Hash = new(string) + } + *game.Hash = hex.EncodeToString(sum) + game.Available = true + if game.LastUpdate == nil { + game.LastUpdate = new(time.Time) + } + *game.LastUpdate = time.Now() + return nil +} + +func UnlockGame(gameId int) { + mu.Lock() + defer mu.Unlock() + delete(locks, gameId) +} + // clearLocks clear lock of zombi upload func clearLocks() { mu.Lock() diff --git a/server/data.go b/server/data.go index 0182d8f..4ce3506 100644 --- a/server/data.go +++ b/server/data.go @@ -2,10 +2,13 @@ package server import ( "encoding/json" + "github.com/go-chi/chi/v5" "io" "log" "net/http" "opensavecloudserver/database" + "strconv" + "time" ) type NewGameInfo struct { @@ -48,6 +51,27 @@ func CreateGame(w http.ResponseWriter, r *http.Request) { ok(game, w, r) } +func GameInfoByID(w http.ResponseWriter, r *http.Request) { + userId, err := userIdFromContext(r.Context()) + if err != nil { + internalServerError(w, r) + return + } + queryId := chi.URLParam(r, "id") + id, err := strconv.Atoi(queryId) + if err != nil { + badRequest("Game ID missing or not an int", w, r) + return + } + game, err := database.GameInfoById(userId, id) + if err != nil { + internalServerError(w, r) + log.Println(err) + return + } + ok(game, w, r) +} + func AskForUpload(w http.ResponseWriter, r *http.Request) { userId, err := userIdFromContext(r.Context()) if err != nil { @@ -74,3 +98,47 @@ func AskForUpload(w http.ResponseWriter, r *http.Request) { } ok(token, w, r) } + +func UploadSave(w http.ResponseWriter, r *http.Request) { + userId, err := userIdFromContext(r.Context()) + if err != nil { + internalServerError(w, r) + return + } + gameId, err := gameIdFromContext(r.Context()) + if err != nil { + internalServerError(w, r) + return + } + defer database.UnlockGame(gameId) + game, err := database.GameInfoById(userId, gameId) + if err != nil { + internalServerError(w, r) + return + } + file, _, err := r.FormFile("file") + if err != nil { + internalServerError(w, r) + log.Println(err) + return + } + defer file.Close() + err = database.UploadSave(file, game) + if err != nil { + internalServerError(w, r) + log.Println(err) + return + } + err = database.UpdateGameRevision(game) + if err != nil { + internalServerError(w, r) + log.Println(err) + return + } + payload := &successMessage{ + Message: "Game uploaded", + Timestamp: time.Now(), + Status: 200, + } + ok(payload, w, r) +} diff --git a/server/server.go b/server/server.go index 60d3f86..b7ee617 100644 --- a/server/server.go +++ b/server/server.go @@ -9,11 +9,15 @@ import ( "net/http" "opensavecloudserver/authentication" "opensavecloudserver/config" + "opensavecloudserver/database" ) type ContextKey string -const UserIdKey ContextKey = "userId" +const ( + UserIdKey ContextKey = "userId" + GameIdKey ContextKey = "gameId" +) // Serve start the http server func Serve() { @@ -30,11 +34,17 @@ func Serve() { r.Route("/system", func(systemRouter chi.Router) { systemRouter.Get("/information", Information) }) - r.Group(func(secureRouter chi.Router) { + r.Route("/game", func(secureRouter chi.Router) { secureRouter.Use(authMiddleware) - secureRouter.Post("/game/create", CreateGame) - secureRouter.Post("/game/upload/init", AskForUpload) + secureRouter.Post("/create", CreateGame) + secureRouter.Get("/{id}", GameInfoByID) + secureRouter.Post("/upload/init", AskForUpload) + secureRouter.Group(func(uploadRouter chi.Router) { + uploadRouter.Use(uploadMiddleware) + uploadRouter.Post("/upload", UploadSave) + }) }) + }) }) log.Println("Server is listening...") @@ -44,7 +54,7 @@ func Serve() { } } -// authMiddleware filter the request +// authMiddleware check the authentication token before accessing to the resource func authMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { header := r.Header.Get("Authorization") @@ -63,6 +73,22 @@ func authMiddleware(next http.Handler) http.Handler { }) } +// uploadMiddleware check the upload key before allowing to upload a file +func uploadMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + header := r.Header.Get("X-Upload-Key") + if len(header) > 0 { + if gameId, ok := database.CheckUploadToken(header); ok { + ctx := context.WithValue(r.Context(), GameIdKey, gameId) + r = r.WithContext(ctx) + next.ServeHTTP(w, r) + return + } + } + unauthorized(w, r) + }) +} + func userIdFromContext(ctx context.Context) (int, error) { if userId, ok := ctx.Value(UserIdKey).(int); ok { return userId, nil @@ -70,6 +96,13 @@ func userIdFromContext(ctx context.Context) (int, error) { return 0, errors.New("userId not found in context") } +func gameIdFromContext(ctx context.Context) (int, error) { + if gameId, ok := ctx.Value(GameIdKey).(int); ok { + return gameId, nil + } + return 0, errors.New("gameId not found in context") +} + func recovery(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer func() {