diff --git a/internal/file/routes.go b/internal/file/routes.go index 8c2234e..690ae7c 100644 --- a/internal/file/routes.go +++ b/internal/file/routes.go @@ -10,6 +10,7 @@ func RegisterRoutes(r *gin.RouterGroup, h *Handler) { files := r.Group("/files") files.POST("/upload", h.Upload) + files.POST("/upload-multi", h.UploadMulti) //files.GET("/download/:id", h.Download) files.GET("/view/:id", h.View) diff --git a/internal/file/upload_multi.go b/internal/file/upload_multi.go new file mode 100644 index 0000000..a240d0b --- /dev/null +++ b/internal/file/upload_multi.go @@ -0,0 +1,57 @@ +package file + +import ( + "net/http" + "strconv" + "time" + + "github.com/gin-gonic/gin" +) + +// UploadMulti accepts up to 10 files, zips them server-side, and returns a single download/view key. +func (h *Handler) UploadMulti(c *gin.Context) { + if err := c.Request.ParseMultipartForm(0); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + form, err := c.MultipartForm() + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid multipart form"}) + return + } + + files := form.File["files"] + if len(files) == 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "missing files"}) + return + } + if len(files) > 10 { + c.JSON(http.StatusBadRequest, gin.H{"error": "too many files (max 10)"}) + return + } + + once := c.PostForm("once") == "true" + + durationStr := c.PostForm("duration") + hours, err := strconv.Atoi(durationStr) + if err != nil || hours <= 0 { + hours = 24 + } + duration := time.Duration(hours) * time.Hour + + record, err := h.service.UploadBundle(files, once, duration) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "id": record.ID, + "deletion_id": record.DeletionID, + "filename": record.Filename, + "size": record.Size, + "expires_at": record.ExpiresAt, + "view_key": record.ViewID, + }) +} diff --git a/internal/file/zip.go b/internal/file/zip.go new file mode 100644 index 0000000..df73ced --- /dev/null +++ b/internal/file/zip.go @@ -0,0 +1,151 @@ +package file + +import ( + "archive/zip" + "errors" + "fmt" + "io" + "mime/multipart" + "os" + "path/filepath" + "strings" + "time" + + "github.com/google/uuid" +) + +func safeZipName(name string) string { + name = filepath.Base(name) + name = strings.ReplaceAll(name, "\\", "_") + name = strings.ReplaceAll(name, "/", "_") + name = strings.TrimSpace(name) + if name == "" || name == "." { + return "file" + } + return name +} + +func dedupeName(name string, seen map[string]int) string { + if _, ok := seen[name]; !ok { + seen[name] = 1 + return name + } + seen[name]++ + ext := filepath.Ext(name) + base := strings.TrimSuffix(name, ext) + return fmt.Sprintf("%s (%d)%s", base, seen[name], ext) +} + +// UploadBundle zips multiple uploaded files into a single .zip stored on disk and tracked as one FileRecord. +func (s *Service) UploadBundle(files []*multipart.FileHeader, deleteAfterDownload bool, expiresAfter time.Duration) (*FileRecord, error) { + if len(files) == 0 { + return nil, errors.New("no files") + } + if len(files) > 10 { + return nil, errors.New("too many files (max 10)") + } + + folderID := uuid.NewString() + folderPath := filepath.Join(s.storageDir, folderID) + if err := os.MkdirAll(folderPath, os.ModePerm); err != nil { + return nil, err + } + + zipDiskName := uuid.NewString() + ".zip" + zipPath := filepath.Join(folderPath, zipDiskName) + + out, err := os.Create(zipPath) + if err != nil { + return nil, err + } + defer func() { + _ = out.Close() + if err != nil { + _ = os.Remove(zipPath) + } + }() + + zw := zip.NewWriter(out) + defer func() { _ = zw.Close() }() + + seen := map[string]int{} + for _, fh := range files { + rc, openErr := fh.Open() + if openErr != nil { + err = openErr + return nil, openErr + } + + name := dedupeName(safeZipName(fh.Filename), seen) + h, _ := zip.FileInfoHeader(dummyFileInfo{name: name, size: fh.Size, mod: time.Now()}) + h.Name = name + h.Method = zip.Deflate + + w, createErr := zw.CreateHeader(h) + if createErr != nil { + _ = rc.Close() + err = createErr + return nil, createErr + } + + if _, copyErr := io.Copy(w, rc); copyErr != nil { + _ = rc.Close() + err = copyErr + return nil, copyErr + } + _ = rc.Close() + } + + if closeErr := zw.Close(); closeErr != nil { + err = closeErr + return nil, closeErr + } + if closeErr := out.Close(); closeErr != nil { + err = closeErr + return nil, closeErr + } + + zipDisplayName := fmt.Sprintf("bundle-%d-files.zip", len(files)) + + f := &FileRecord{ + ID: folderID, + DeletionID: uuid.NewString(), + ViewID: uuid.NewString(), + Filename: zipDisplayName, + Path: zipPath, + Size: fileSize(zipPath), + CreatedAt: time.Now(), + ExpiresAt: time.Now().Add(expiresAfter), + DeleteAfterDownload: deleteAfterDownload, + } + + if err := s.repo.Create(f); err != nil { + return nil, err + } + + return f, nil +} + +func fileSize(path string) int64 { + st, err := os.Stat(path) + if err != nil { + return 0 + } + return st.Size() +} + +// dummyFileInfo provides minimal os.FileInfo for zip headers. +// This avoids relying on the underlying uploaded file having a real modtime. +// (zip.Writer can work without this too, but headers look nicer.) +type dummyFileInfo struct { + name string + size int64 + mod time.Time +} + +func (d dummyFileInfo) Name() string { return d.name } +func (d dummyFileInfo) Size() int64 { return d.size } +func (d dummyFileInfo) Mode() os.FileMode { return 0o644 } +func (d dummyFileInfo) ModTime() time.Time { return d.mod } +func (d dummyFileInfo) IsDir() bool { return false } +func (d dummyFileInfo) Sys() any { return nil } diff --git a/templates/index.html b/templates/index.html index 8386fb9..93225b9 100644 --- a/templates/index.html +++ b/templates/index.html @@ -102,7 +102,7 @@