diff --git a/internal/file/handlers.go b/internal/file/handlers.go index 5dd331e..3bba5b9 100644 --- a/internal/file/handlers.go +++ b/internal/file/handlers.go @@ -3,6 +3,7 @@ package file import ( "fmt" "net/http" + "path/filepath" "strconv" "time" @@ -64,9 +65,49 @@ func (h *Handler) Upload(c *gin.Context) { "filename": record.Filename, "size": record.Size, "expires_at": record.ExpiresAt, + "view_key": record.ViewID, }) } +func (h *Handler) View(c *gin.Context) { + id := c.Param("id") + + record, err := h.service.DownloadFile(id) + if err != nil { + c.HTML(http.StatusOK, "fileNotFound.html", nil) + return + } + + c.Header("Content-Disposition", fmt.Sprintf(`inline; filename="%s"`, record.Filename)) + c.Header("X-Content-Type-Options", "nosniff") + c.File(record.Path) +} + +func safeFilename(name string) string { + // keep it simple: drop control chars and quotes + out := make([]rune, 0, len(name)) + for _, r := range name { + if r < 32 || r == 127 || r == '"' || r == '\\' { + continue + } + out = append(out, r) + } + if len(out) == 0 { + return "file" + } + return string(out) +} + +func isXSSRisk(filename string) bool { + ext := filepath.Ext(filename) + switch ext { + case ".html", ".htm", ".js", ".css", ".svg": + return true + default: + return false + } +} + func (h *Handler) Download(c *gin.Context) { id := c.Param("id") @@ -75,7 +116,11 @@ func (h *Handler) Download(c *gin.Context) { c.HTML(http.StatusOK, "fileNotFound.html", nil) return } + c.Header("Content-Disposition", fmt.Sprintf(`inline; filename="%s"`, record.Filename)) + c.Header("X-Content-Type-Options", "nosniff") + //c.Header("Content-Security-Policy", "default-src 'none'; img-src 'self'; media-src 'self'; script-src 'none'; style-src 'none';") + //c.Header("Content-Type", "application/octet-stream") c.File(record.Path) } diff --git a/internal/file/repository.go b/internal/file/repository.go index 47d94bc..7d66f42 100644 --- a/internal/file/repository.go +++ b/internal/file/repository.go @@ -69,6 +69,21 @@ func (r *Repository) GetByDeletionID(delID string) (*FileRecord, error) { return &f, nil } +func (r *Repository) Update(f *FileRecord) error { + return r.db.Save(f).Error +} + +func (r *Repository) GetFileByViewID(viewID string) (*FileRecord, error) { + var f FileRecord + if err := r.db.First(&f, "view_id = ?", viewID).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrFileNotFound + } + return nil, err + } + return &f, nil +} + func (r *Repository) IncrementDownload(f *FileRecord) error { f.DownloadCount++ return r.db.Save(f).Error diff --git a/internal/file/routes.go b/internal/file/routes.go index fd37202..8c2234e 100644 --- a/internal/file/routes.go +++ b/internal/file/routes.go @@ -10,8 +10,9 @@ func RegisterRoutes(r *gin.RouterGroup, h *Handler) { files := r.Group("/files") files.POST("/upload", h.Upload) - files.GET("/download/:id", h.Download) + //files.GET("/download/:id", h.Download) + files.GET("/view/:id", h.View) files.GET("/delete/:del_id", h.Delete) adminRoutes := files.Group("/admin") diff --git a/internal/file/service.go b/internal/file/service.go index a65f39c..4a88f1c 100644 --- a/internal/file/service.go +++ b/internal/file/service.go @@ -47,6 +47,7 @@ func (s *Service) UploadFile(filename string, data io.Reader, deleteAfterDownloa f := &FileRecord{ ID: folderID, DeletionID: uuid.NewString(), + ViewID: uuid.NewString(), Filename: filename, Path: path, Size: size, @@ -145,6 +146,10 @@ func (s *Service) GetFileByDeletionID(delID string) (*FileRecord, error) { return s.repo.GetByDeletionID(delID) } +func (s *Service) GetFileByViewID(viewID string) (*FileRecord, error) { + return s.repo.GetFileByViewID(viewID) +} + func (s *Service) ImportFiles(records []ImportFileRecord) error { for _, r := range records { diff --git a/internal/web/handler.go b/internal/web/handler.go index 12d456e..483e3e0 100644 --- a/internal/web/handler.go +++ b/internal/web/handler.go @@ -33,6 +33,25 @@ func (h *Handler) LoginPage(c *gin.Context) { c.HTML(200, "login.html", nil) } +func (h *Handler) FileView(c *gin.Context) { + id := c.Param("id") + + fileRecord, err := h.fileService.GetFileByViewID(id) + if err != nil { + c.HTML(404, "fileNotFound.html", nil) + return + } + + downloadKey := fileRecord.ID + deleteKey := fileRecord.DeletionID + + c.HTML(200, "complete.html", gin.H{ + "Filename": fileRecord.Filename, + "DownloadID": downloadKey, + "DeleteID": deleteKey, + }) +} + func (h *Handler) AdminPage(c *gin.Context) { pageStr := c.Query("page") page, err := strconv.Atoi(pageStr) diff --git a/internal/web/routes.go b/internal/web/routes.go index c436a7b..fe954a6 100644 --- a/internal/web/routes.go +++ b/internal/web/routes.go @@ -12,6 +12,8 @@ func RegisterRoutes(r *gin.Engine, h *Handler, userService *user.Service) { //r.GET("/upload", h.UploadPage) r.GET("/login", h.LoginPage) + r.GET("/f/:id", h.FileView) + adminRoutes := r.Group("/") adminRoutes.Use(middleware.AuthMiddleware()) adminRoutes.Use(middleware.RequireRole("admin")) diff --git a/templates/complete.html b/templates/complete.html new file mode 100644 index 0000000..1ccab5b --- /dev/null +++ b/templates/complete.html @@ -0,0 +1,132 @@ + + +
+ + +
+
+ + A service by Brammie15 +
+ +