Files
backroad/api/http/handler/extensions/extension_create.go
Anthony Lapenna d5cee5b8b1 feat(core/extensions): add the ability to update a license (#4081)
* feat(core/extensions): add the ability to update a license

* feat(core/extensions): trigger data upgrade if extension is not enabled yet

* feat(core/extensions): trigger data upgrade if extension is not enabled yet

* feat(core/extensions): trigger data upgrade if extension is not enabled yet

* feat(core/extensions): trigger data upgrade if extension is not enabled yet
2020-07-22 21:13:51 +12:00

91 lines
2.7 KiB
Go

package extensions
import (
"net/http"
"strconv"
"github.com/asaskevich/govalidator"
httperror "github.com/portainer/libhttp/error"
"github.com/portainer/libhttp/request"
"github.com/portainer/libhttp/response"
"github.com/portainer/portainer/api"
)
// extensionCreatePayload is the JSON request body decoded by extensionCreate.
type extensionCreatePayload struct {
	// License is the raw license key. extensionCreate parses its first
	// character as the numeric extension identifier.
	License string
}
// Validate checks the decoded payload. It returns an "Invalid license" error
// when govalidator.IsNull reports the License field as null/empty, and nil
// otherwise. The request argument is not used.
func (p *extensionCreatePayload) Validate(_ *http.Request) error {
	if !govalidator.IsNull(p.License) {
		return nil
	}
	return portainer.Error("Invalid license")
}
// extensionCreate enables an extension from a license key, or replaces the
// license of an extension that is already registered.
//
// The extension is identified by the first character of payload.License,
// which must parse as a decimal digit (strconv.Atoi on that single byte).
//
// Error responses:
//   - 400 when the payload cannot be decoded/validated, or the license does
//     not start with a digit;
//   - 409 (portainer.ErrExtensionAlreadyEnabled) when the matching extension
//     is enabled or holds an invalid license, and the submitted key is
//     identical to the stored one;
//   - 500 when reading the extensions, fetching the extension definitions,
//     enabling the extension, upgrading RBAC data or persisting fails.
//
// On success the extension is persisted with Enabled set to true and an
// empty response is written via response.Empty.
func (handler *Handler) extensionCreate(w http.ResponseWriter, r *http.Request) *httperror.HandlerError {
	var payload extensionCreatePayload
	err := request.DecodeAndValidateJSONPayload(r, &payload)
	if err != nil {
		return &httperror.HandlerError{StatusCode: http.StatusBadRequest, Message: "Invalid request payload", Err: err}
	}

	// Validate has already rejected an empty License, so the first byte exists.
	extensionIdentifier, err := strconv.Atoi(string(payload.License[0]))
	if err != nil {
		return &httperror.HandlerError{StatusCode: http.StatusBadRequest, Message: "Invalid license format", Err: err}
	}
	extensionID := portainer.ExtensionID(extensionIdentifier)

	extensions, err := handler.ExtensionService.Extensions()
	if err != nil {
		return &httperror.HandlerError{StatusCode: http.StatusInternalServerError, Message: "Unable to retrieve extensions status from the database", Err: err}
	}

	extension := &portainer.Extension{
		ID: extensionID,
	}

	for _, existingExtension := range extensions {
		if existingExtension.ID != extensionID {
			continue
		}

		// An extension that is enabled, or whose stored license is not
		// valid, is being re-licensed rather than enabled for the first time.
		if existingExtension.Enabled || !existingExtension.License.Valid {
			if existingExtension.License.LicenseKey == payload.License {
				return &httperror.HandlerError{StatusCode: http.StatusConflict, Message: "Unable to enable extension", Err: portainer.ErrExtensionAlreadyEnabled}
			}

			// Best effort: the error is deliberately discarded so that the
			// new license is still applied below.
			// NOTE(review): if EnableExtension fails afterwards, nothing
			// re-enables the previous instance — consider a rollback.
			_ = handler.ExtensionManager.DisableExtension(&existingExtension)

			// Marks this request as a license update; the RBAC data upgrade
			// below is skipped when this flag is set.
			extension.Enabled = true
		}
	}

	extensionDefinitions, err := handler.ExtensionManager.FetchExtensionDefinitions()
	if err != nil {
		return &httperror.HandlerError{StatusCode: http.StatusInternalServerError, Message: "Unable to retrieve extension definitions", Err: err}
	}

	// Copy the version advertised by the matching definition.
	// NOTE(review): when no definition matches, Version stays empty and
	// EnableExtension is still called — confirm this is intended.
	for _, def := range extensionDefinitions {
		if def.ID == extension.ID {
			extension.Version = def.Version
			break
		}
	}

	err = handler.ExtensionManager.EnableExtension(extension, payload.License)
	if err != nil {
		return &httperror.HandlerError{StatusCode: http.StatusInternalServerError, Message: "Unable to enable extension", Err: err}
	}

	// Upgrade RBAC data only when the loop above did not mark this request
	// as a license update.
	if extension.ID == portainer.RBACExtension && !extension.Enabled {
		err = handler.upgradeRBACData()
		if err != nil {
			return &httperror.HandlerError{StatusCode: http.StatusInternalServerError, Message: "An error occured during database update", Err: err}
		}
	}

	extension.Enabled = true
	err = handler.ExtensionService.Persist(extension)
	if err != nil {
		return &httperror.HandlerError{StatusCode: http.StatusInternalServerError, Message: "Unable to persist extension status inside the database", Err: err}
	}

	return response.Empty(w)
}