Compare commits
54 Commits
develop
...
release/2.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bb9f93f5fc | ||
|
|
a69470ec08 | ||
|
|
ea6f1c97f5 | ||
|
|
6d058987f3 | ||
|
|
6998f05855 | ||
|
|
94d01c58fc | ||
|
|
d98eb77067 | ||
|
|
941e86563a | ||
|
|
f72d6b97d3 | ||
|
|
32926aa8bf | ||
|
|
1849c61c38 | ||
|
|
fd6d74602c | ||
|
|
74b1dd04d1 | ||
|
|
7450501b7a | ||
|
|
dcfe2d9809 | ||
|
|
c21c91632f | ||
|
|
732337615e | ||
|
|
6ea16c0060 | ||
|
|
4e7d4b60a5 | ||
|
|
19e1cc2fbd | ||
|
|
68b9fef3f0 | ||
|
|
1e47df6611 | ||
|
|
405ce8f671 | ||
|
|
e9d31b3b7b | ||
|
|
f97adc94ad | ||
|
|
11d6341765 | ||
|
|
c3cf46b0e0 | ||
|
|
ff746beba1 | ||
|
|
da1672fc17 | ||
|
|
7a9376cbaf | ||
|
|
c0f6410d80 | ||
|
|
4b9ab98fd2 | ||
|
|
3354ee4e4b | ||
|
|
af3c45bea0 | ||
|
|
816a6f9bef | ||
|
|
e86ea22900 | ||
|
|
12b2acbc00 | ||
|
|
4a8b42928e | ||
|
|
2e828b39da | ||
|
|
49c6521c23 | ||
|
|
debf1a742b | ||
|
|
5d3708ec3e | ||
|
|
9320fd4c50 | ||
|
|
974682bd98 | ||
|
|
631f1deb2e | ||
|
|
4169b045fb | ||
|
|
0a2a786aa3 | ||
|
|
808f87206e | ||
|
|
ed6fa82904 | ||
|
|
9fc301110b | ||
|
|
69101ac89a | ||
|
|
69d33dd432 | ||
|
|
389cbf748c | ||
|
|
d01b31f707 |
@@ -9,8 +9,8 @@ import (
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
|
||||
"github.com/alecthomas/kingpin/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gopkg.in/alecthomas/kingpin.v2"
|
||||
)
|
||||
|
||||
// Service implements the CLIService interface
|
||||
@@ -35,16 +35,9 @@ func CLIFlags() *portainer.CLIFlags {
|
||||
FeatureFlags: kingpin.Flag("feat", "List of feature flags").Strings(),
|
||||
EnableEdgeComputeFeatures: kingpin.Flag("edge-compute", "Enable Edge Compute features").Bool(),
|
||||
NoAnalytics: kingpin.Flag("no-analytics", "Disable Analytics in app (deprecated)").Bool(),
|
||||
TLS: kingpin.Flag("tlsverify", "TLS support").Default(defaultTLS).Bool(),
|
||||
TLSSkipVerify: kingpin.Flag("tlsskipverify", "Disable TLS server verification").Default(defaultTLSSkipVerify).Bool(),
|
||||
TLSCacert: kingpin.Flag("tlscacert", "Path to the CA").Default(defaultTLSCACertPath).String(),
|
||||
TLSCert: kingpin.Flag("tlscert", "Path to the TLS certificate file").Default(defaultTLSCertPath).String(),
|
||||
TLSKey: kingpin.Flag("tlskey", "Path to the TLS key").Default(defaultTLSKeyPath).String(),
|
||||
HTTPDisabled: kingpin.Flag("http-disabled", "Serve portainer only on https").Default(defaultHTTPDisabled).Bool(),
|
||||
HTTPEnabled: kingpin.Flag("http-enabled", "Serve portainer on http").Default(defaultHTTPEnabled).Bool(),
|
||||
SSL: kingpin.Flag("ssl", "Secure Portainer instance using SSL (deprecated)").Default(defaultSSL).Bool(),
|
||||
SSLCert: kingpin.Flag("sslcert", "Path to the SSL certificate used to secure the Portainer instance").String(),
|
||||
SSLKey: kingpin.Flag("sslkey", "Path to the SSL key used to secure the Portainer instance").String(),
|
||||
Rollback: kingpin.Flag("rollback", "Rollback the database to the previous backup").Bool(),
|
||||
SnapshotInterval: kingpin.Flag("snapshot-interval", "Duration between each environment snapshot job").String(),
|
||||
AdminPassword: kingpin.Flag("admin-password", "Set admin password with provided hash").String(),
|
||||
@@ -63,6 +56,7 @@ func CLIFlags() *portainer.CLIFlags {
|
||||
PullLimitCheckDisabled: kingpin.Flag("pull-limit-check-disabled", "Pull limit check").Envar(portainer.PullLimitCheckDisabledEnvVar).Default(defaultPullLimitCheckDisabled).Bool(),
|
||||
TrustedOrigins: kingpin.Flag("trusted-origins", "List of trusted origins for CSRF protection. Separate multiple origins with a comma.").Envar(portainer.TrustedOriginsEnvVar).String(),
|
||||
CSP: kingpin.Flag("csp", "Content Security Policy (CSP) header").Envar(portainer.CSPEnvVar).Default("true").Bool(),
|
||||
CompactDB: kingpin.Flag("compact-db", "Enable database compaction on startup").Envar(portainer.CompactDBEnvVar).Default("false").Bool(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,8 +64,37 @@ func CLIFlags() *portainer.CLIFlags {
|
||||
func (Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
|
||||
kingpin.Version(version)
|
||||
|
||||
var hasSSLFlag, hasSSLCertFlag, hasSSLKeyFlag bool
|
||||
sslFlag := kingpin.Flag(
|
||||
"ssl",
|
||||
"Secure Portainer instance using SSL (deprecated)",
|
||||
).Default(defaultSSL).IsSetByUser(&hasSSLFlag)
|
||||
ssl := sslFlag.Bool()
|
||||
sslCertFlag := kingpin.Flag(
|
||||
"sslcert",
|
||||
"Path to the SSL certificate used to secure the Portainer instance",
|
||||
).IsSetByUser(&hasSSLCertFlag)
|
||||
sslCert := sslCertFlag.String()
|
||||
sslKeyFlag := kingpin.Flag(
|
||||
"sslkey",
|
||||
"Path to the SSL key used to secure the Portainer instance",
|
||||
).IsSetByUser(&hasSSLKeyFlag)
|
||||
sslKey := sslKeyFlag.String()
|
||||
|
||||
flags := CLIFlags()
|
||||
|
||||
var hasTLSFlag, hasTLSCertFlag, hasTLSKeyFlag bool
|
||||
tlsFlag := kingpin.Flag("tlsverify", "TLS support").Default(defaultTLS).IsSetByUser(&hasTLSFlag)
|
||||
flags.TLS = tlsFlag.Bool()
|
||||
tlsCertFlag := kingpin.Flag(
|
||||
"tlscert",
|
||||
"Path to the TLS certificate file",
|
||||
).Default(defaultTLSCertPath).IsSetByUser(&hasTLSCertFlag)
|
||||
flags.TLSCert = tlsCertFlag.String()
|
||||
tlsKeyFlag := kingpin.Flag("tlskey", "Path to the TLS key").Default(defaultTLSKeyPath).IsSetByUser(&hasTLSKeyFlag)
|
||||
flags.TLSKey = tlsKeyFlag.String()
|
||||
flags.TLSCacert = kingpin.Flag("tlscacert", "Path to the CA").Default(defaultTLSCACertPath).String()
|
||||
|
||||
kingpin.Parse()
|
||||
|
||||
if !filepath.IsAbs(*flags.Assets) {
|
||||
@@ -83,6 +106,41 @@ func (Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
|
||||
*flags.Assets = filepath.Join(filepath.Dir(ex), *flags.Assets)
|
||||
}
|
||||
|
||||
// If the user didn't provide a tls flag remove the defaults to match previous behaviour
|
||||
if !hasTLSFlag {
|
||||
if !hasTLSCertFlag {
|
||||
*flags.TLSCert = ""
|
||||
}
|
||||
|
||||
if !hasTLSKeyFlag {
|
||||
*flags.TLSKey = ""
|
||||
}
|
||||
}
|
||||
|
||||
if hasSSLFlag {
|
||||
log.Warn().Msgf("the %q flag is deprecated. use %q instead.", sslFlag.Model().Name, tlsFlag.Model().Name)
|
||||
|
||||
if !hasTLSFlag {
|
||||
flags.TLS = ssl
|
||||
}
|
||||
}
|
||||
|
||||
if hasSSLCertFlag {
|
||||
log.Warn().Msgf("the %q flag is deprecated. use %q instead.", sslCertFlag.Model().Name, tlsCertFlag.Model().Name)
|
||||
|
||||
if !hasTLSCertFlag {
|
||||
flags.TLSCert = sslCert
|
||||
}
|
||||
}
|
||||
|
||||
if hasSSLKeyFlag {
|
||||
log.Warn().Msgf("the %q flag is deprecated. use %q instead.", sslKeyFlag.Model().Name, tlsKeyFlag.Model().Name)
|
||||
|
||||
if !hasTLSKeyFlag {
|
||||
flags.TLSKey = sslKey
|
||||
}
|
||||
}
|
||||
|
||||
return flags, nil
|
||||
}
|
||||
|
||||
@@ -109,10 +167,6 @@ func displayDeprecationWarnings(flags *portainer.CLIFlags) {
|
||||
if *flags.NoAnalytics {
|
||||
log.Warn().Msg("the --no-analytics flag has been kept to allow migration of instances running a previous version of Portainer with this flag enabled, to version 2.0 where enabling this flag will have no effect")
|
||||
}
|
||||
|
||||
if *flags.SSL {
|
||||
log.Warn().Msg("SSL is enabled by default and there is no need for the --ssl flag, it has been kept to allow migration of instances running a previous version of Portainer with this flag enabled")
|
||||
}
|
||||
}
|
||||
|
||||
func validateEndpointURL(endpointURL string) error {
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
zerolog "github.com/rs/zerolog/log"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -22,3 +25,185 @@ func TestOptionParser(t *testing.T) {
|
||||
require.False(t, *opts.HTTPDisabled)
|
||||
require.True(t, *opts.EnableEdgeComputeFeatures)
|
||||
}
|
||||
|
||||
func TestParseTLSFlags(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedTLSFlag bool
|
||||
expectedTLSCertFlag string
|
||||
expectedTLSKeyFlag string
|
||||
expectedLogMessages []string
|
||||
}{
|
||||
{
|
||||
name: "no flags",
|
||||
expectedTLSFlag: false,
|
||||
expectedTLSCertFlag: "",
|
||||
expectedTLSKeyFlag: "",
|
||||
},
|
||||
{
|
||||
name: "only ssl flag",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--ssl",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "",
|
||||
expectedTLSKeyFlag: "",
|
||||
},
|
||||
{
|
||||
name: "only tls flag",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--tlsverify",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: defaultTLSCertPath,
|
||||
expectedTLSKeyFlag: defaultTLSKeyPath,
|
||||
},
|
||||
{
|
||||
name: "partial ssl flags",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--ssl",
|
||||
"--sslcert=ssl-cert-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "ssl-cert-flag-value",
|
||||
expectedTLSKeyFlag: "",
|
||||
},
|
||||
{
|
||||
name: "partial tls flags",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--tlsverify",
|
||||
"--tlscert=tls-cert-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "tls-cert-flag-value",
|
||||
expectedTLSKeyFlag: defaultTLSKeyPath,
|
||||
},
|
||||
{
|
||||
name: "partial tls and ssl flags",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--tlsverify",
|
||||
"--tlscert=tls-cert-flag-value",
|
||||
"--sslkey=ssl-key-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "tls-cert-flag-value",
|
||||
expectedTLSKeyFlag: "ssl-key-flag-value",
|
||||
},
|
||||
{
|
||||
name: "partial tls and ssl flags 2",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--ssl",
|
||||
"--tlscert=tls-cert-flag-value",
|
||||
"--sslkey=ssl-key-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "tls-cert-flag-value",
|
||||
expectedTLSKeyFlag: "ssl-key-flag-value",
|
||||
},
|
||||
{
|
||||
name: "ssl flags",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--ssl",
|
||||
"--sslcert=ssl-cert-flag-value",
|
||||
"--sslkey=ssl-key-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "ssl-cert-flag-value",
|
||||
expectedTLSKeyFlag: "ssl-key-flag-value",
|
||||
expectedLogMessages: []string{
|
||||
"the \\\"ssl\\\" flag is deprecated. use \\\"tlsverify\\\" instead.",
|
||||
"the \\\"sslcert\\\" flag is deprecated. use \\\"tlscert\\\" instead.",
|
||||
"the \\\"sslkey\\\" flag is deprecated. use \\\"tlskey\\\" instead.",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tls flags",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--tlsverify",
|
||||
"--tlscert=tls-cert-flag-value",
|
||||
"--tlskey=tls-key-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "tls-cert-flag-value",
|
||||
expectedTLSKeyFlag: "tls-key-flag-value",
|
||||
},
|
||||
{
|
||||
name: "tls and ssl flags",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--tlsverify",
|
||||
"--tlscert=tls-cert-flag-value",
|
||||
"--tlskey=tls-key-flag-value",
|
||||
"--ssl",
|
||||
"--sslcert=ssl-cert-flag-value",
|
||||
"--sslkey=ssl-key-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "tls-cert-flag-value",
|
||||
expectedTLSKeyFlag: "tls-key-flag-value",
|
||||
expectedLogMessages: []string{
|
||||
"the \\\"ssl\\\" flag is deprecated. use \\\"tlsverify\\\" instead.",
|
||||
"the \\\"sslcert\\\" flag is deprecated. use \\\"tlscert\\\" instead.",
|
||||
"the \\\"sslkey\\\" flag is deprecated. use \\\"tlskey\\\" instead.",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var logOutput strings.Builder
|
||||
setupLogOutput(t, &logOutput)
|
||||
|
||||
if tc.args == nil {
|
||||
tc.args = []string{"portainer"}
|
||||
}
|
||||
setOsArgs(t, tc.args)
|
||||
|
||||
s := Service{}
|
||||
flags, err := s.ParseFlags("test-version")
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing flags: %v", err)
|
||||
}
|
||||
|
||||
if flags.TLS == nil {
|
||||
t.Fatal("TLS flag was nil")
|
||||
}
|
||||
|
||||
require.Equal(t, tc.expectedTLSFlag, *flags.TLS, "tlsverify flag didn't match")
|
||||
require.Equal(t, tc.expectedTLSCertFlag, *flags.TLSCert, "tlscert flag didn't match")
|
||||
require.Equal(t, tc.expectedTLSKeyFlag, *flags.TLSKey, "tlskey flag didn't match")
|
||||
|
||||
for _, expectedLogMessage := range tc.expectedLogMessages {
|
||||
require.Contains(t, logOutput.String(), expectedLogMessage, "Log didn't contain expected message")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func setOsArgs(t *testing.T, args []string) {
|
||||
t.Helper()
|
||||
previousArgs := os.Args
|
||||
os.Args = args
|
||||
t.Cleanup(func() {
|
||||
os.Args = previousArgs
|
||||
})
|
||||
}
|
||||
|
||||
func setupLogOutput(t *testing.T, w io.Writer) {
|
||||
t.Helper()
|
||||
|
||||
oldLogger := zerolog.Logger
|
||||
zerolog.Logger = zerolog.Output(w)
|
||||
t.Cleanup(func() {
|
||||
zerolog.Logger = oldLogger
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/alecthomas/kingpin.v2"
|
||||
"github.com/alecthomas/kingpin/v2"
|
||||
)
|
||||
|
||||
type pairList []portainer.Pair
|
||||
|
||||
@@ -84,7 +84,7 @@ func initFileService(dataStorePath string) portainer.FileService {
|
||||
}
|
||||
|
||||
func initDataStore(flags *portainer.CLIFlags, secretKey []byte, fileService portainer.FileService, shutdownCtx context.Context) dataservices.DataStore {
|
||||
connection, err := database.NewDatabase("boltdb", *flags.Data, secretKey)
|
||||
connection, err := database.NewDatabase("boltdb", *flags.Data, secretKey, *flags.CompactDB)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("failed creating database connection")
|
||||
}
|
||||
@@ -309,13 +309,13 @@ func initKeyPair(fileService portainer.FileService, signatureService portainer.D
|
||||
|
||||
// dbSecretPath build the path to the file that contains the db encryption
|
||||
// secret. Normally in Docker this is built from the static path inside
|
||||
// /run/portainer for example: /run/portainer/<keyFilenameFlag> but for ease of
|
||||
// /run/secrets for example: /run/secrets/<keyFilenameFlag> but for ease of
|
||||
// use outside Docker it also accepts an absolute path
|
||||
func dbSecretPath(keyFilenameFlag string) string {
|
||||
if path.IsAbs(keyFilenameFlag) {
|
||||
return keyFilenameFlag
|
||||
}
|
||||
return path.Join("/run/portainer", keyFilenameFlag)
|
||||
return path.Join("/run/secrets", keyFilenameFlag)
|
||||
}
|
||||
|
||||
func loadEncryptionSecretKey(keyfilename string) []byte {
|
||||
@@ -408,7 +408,7 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
|
||||
|
||||
edgeStacksService := edgestacks.NewService(dataStore)
|
||||
|
||||
sslService, err := initSSLService(*flags.AddrHTTPS, *flags.SSLCert, *flags.SSLKey, fileService, dataStore, shutdownTrigger)
|
||||
sslService, err := initSSLService(*flags.AddrHTTPS, *flags.TLSCert, *flags.TLSKey, fileService, dataStore, shutdownTrigger)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("")
|
||||
}
|
||||
|
||||
@@ -43,12 +43,12 @@ func TestDBSecretPath(t *testing.T) {
|
||||
keyFilenameFlag string
|
||||
expected string
|
||||
}{
|
||||
{keyFilenameFlag: "secret.txt", expected: "/run/portainer/secret.txt"},
|
||||
{keyFilenameFlag: "secret.txt", expected: "/run/secrets/secret.txt"},
|
||||
{keyFilenameFlag: "/tmp/secret.txt", expected: "/tmp/secret.txt"},
|
||||
{keyFilenameFlag: "/run/portainer/secret.txt", expected: "/run/portainer/secret.txt"},
|
||||
{keyFilenameFlag: "./secret.txt", expected: "/run/portainer/secret.txt"},
|
||||
{keyFilenameFlag: "/run/secrets/secret.txt", expected: "/run/secrets/secret.txt"},
|
||||
{keyFilenameFlag: "./secret.txt", expected: "/run/secrets/secret.txt"},
|
||||
{keyFilenameFlag: "../secret.txt", expected: "/run/secret.txt"},
|
||||
{keyFilenameFlag: "foo/bar/secret.txt", expected: "/run/portainer/foo/bar/secret.txt"},
|
||||
{keyFilenameFlag: "foo/bar/secret.txt", expected: "/run/secrets/foo/bar/secret.txt"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/pbkdf2"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
@@ -15,7 +16,6 @@ import (
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
"golang.org/x/crypto/scrypt"
|
||||
)
|
||||
|
||||
@@ -248,7 +248,10 @@ func aesEncryptGCMFIPS(input io.Reader, output io.Writer, passphrase []byte) err
|
||||
return err
|
||||
}
|
||||
|
||||
key := pbkdf2.Key(passphrase, salt, pbkdf2Iterations, 32, sha256.New)
|
||||
key, err := pbkdf2.Key(sha256.New, string(passphrase), salt, pbkdf2Iterations, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error deriving key: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
@@ -315,7 +318,10 @@ func aesDecryptGCMFIPS(input io.Reader, passphrase []byte) (io.Reader, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := pbkdf2.Key(passphrase, salt, pbkdf2Iterations, 32, sha256.New)
|
||||
key, err := pbkdf2.Key(sha256.New, string(passphrase), salt, pbkdf2Iterations, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error deriving key: %w", err)
|
||||
}
|
||||
|
||||
// Initialize AES cipher block
|
||||
block, err := aes.NewCipher(key)
|
||||
@@ -382,3 +388,18 @@ func aesDecryptOFB(input io.Reader, passphrase []byte) (io.Reader, error) {
|
||||
|
||||
return reader, nil
|
||||
}
|
||||
|
||||
// HasEncryptedHeader checks if the data has an encrypted header, note that fips
|
||||
// mode changes this behavior and so will only recognize data encrypted by the
|
||||
// same mode (fips enabled or disabled)
|
||||
func HasEncryptedHeader(data []byte) bool {
|
||||
return hasEncryptedHeader(data, fips.FIPSMode())
|
||||
}
|
||||
|
||||
func hasEncryptedHeader(data []byte, fipsMode bool) bool {
|
||||
if fipsMode {
|
||||
return bytes.HasPrefix(data, []byte(aesGcmFIPSHeader))
|
||||
}
|
||||
|
||||
return bytes.HasPrefix(data, []byte(aesGcmHeader))
|
||||
}
|
||||
|
||||
@@ -350,3 +350,62 @@ func legacyAesEncrypt(input io.Reader, output io.Writer, passphrase []byte) erro
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func Test_hasEncryptedHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
fipsMode bool
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "non-FIPS mode with valid header",
|
||||
data: []byte("AES256-GCM" + "some encrypted data"),
|
||||
fipsMode: false,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "non-FIPS mode with FIPS header",
|
||||
data: []byte("FIPS-AES256-GCM" + "some encrypted data"),
|
||||
fipsMode: false,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "FIPS mode with valid header",
|
||||
data: []byte("FIPS-AES256-GCM" + "some encrypted data"),
|
||||
fipsMode: true,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "FIPS mode with non-FIPS header",
|
||||
data: []byte("AES256-GCM" + "some encrypted data"),
|
||||
fipsMode: true,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "invalid header",
|
||||
data: []byte("INVALID-HEADER" + "some data"),
|
||||
fipsMode: false,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty data",
|
||||
data: []byte{},
|
||||
fipsMode: false,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "nil data",
|
||||
data: nil,
|
||||
fipsMode: false,
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := hasEncryptedHeader(tt.data, tt.fipsMode)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/fips140"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"os"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
)
|
||||
|
||||
// CreateTLSConfiguration creates a basic tls.Config with recommended TLS settings
|
||||
func CreateTLSConfiguration(insecureSkipVerify bool) *tls.Config { //nolint:forbidigo
|
||||
// TODO: use fips.FIPSMode() instead
|
||||
return createTLSConfiguration(fips140.Enabled(), insecureSkipVerify)
|
||||
return createTLSConfiguration(fips.FIPSMode(), insecureSkipVerify)
|
||||
}
|
||||
|
||||
func createTLSConfiguration(fipsEnabled bool, insecureSkipVerify bool) *tls.Config { //nolint:forbidigo
|
||||
@@ -58,8 +57,7 @@ func createTLSConfiguration(fipsEnabled bool, insecureSkipVerify bool) *tls.Conf
|
||||
// CreateTLSConfigurationFromBytes initializes a tls.Config using a CA certificate, a certificate and a key
|
||||
// loaded from memory.
|
||||
func CreateTLSConfigurationFromBytes(useTLS bool, caCert, cert, key []byte, skipClientVerification, skipServerVerification bool) (*tls.Config, error) { //nolint:forbidigo
|
||||
// TODO: use fips.FIPSMode() instead
|
||||
return createTLSConfigurationFromBytes(fips140.Enabled(), useTLS, caCert, cert, key, skipClientVerification, skipServerVerification)
|
||||
return createTLSConfigurationFromBytes(fips.FIPSMode(), useTLS, caCert, cert, key, skipClientVerification, skipServerVerification)
|
||||
}
|
||||
|
||||
func createTLSConfigurationFromBytes(fipsEnabled, useTLS bool, caCert, cert, key []byte, skipClientVerification, skipServerVerification bool) (*tls.Config, error) { //nolint:forbidigo
|
||||
@@ -90,8 +88,7 @@ func createTLSConfigurationFromBytes(fipsEnabled, useTLS bool, caCert, cert, key
|
||||
// CreateTLSConfigurationFromDisk initializes a tls.Config using a CA certificate, a certificate and a key
|
||||
// loaded from disk.
|
||||
func CreateTLSConfigurationFromDisk(config portainer.TLSConfiguration) (*tls.Config, error) { //nolint:forbidigo
|
||||
// TODO: use fips.FIPSMode() instead
|
||||
return createTLSConfigurationFromDisk(fips140.Enabled(), config)
|
||||
return createTLSConfigurationFromDisk(fips.FIPSMode(), config)
|
||||
}
|
||||
|
||||
func createTLSConfigurationFromDisk(fipsEnabled bool, config portainer.TLSConfiguration) (*tls.Config, error) { //nolint:forbidigo
|
||||
|
||||
@@ -21,6 +21,9 @@ import (
|
||||
const (
|
||||
DatabaseFileName = "portainer.db"
|
||||
EncryptedDatabaseFileName = "portainer.edb"
|
||||
|
||||
txMaxSize = 65536
|
||||
compactedSuffix = ".compacted"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -35,6 +38,7 @@ type DbConnection struct {
|
||||
InitialMmapSize int
|
||||
EncryptionKey []byte
|
||||
isEncrypted bool
|
||||
Compact bool
|
||||
|
||||
*bolt.DB
|
||||
}
|
||||
@@ -132,15 +136,8 @@ func (connection *DbConnection) NeedsEncryptionMigration() (bool, error) {
|
||||
func (connection *DbConnection) Open() error {
|
||||
log.Info().Str("filename", connection.GetDatabaseFileName()).Msg("loading PortainerDB")
|
||||
|
||||
// Now we open the db
|
||||
databasePath := connection.GetDatabaseFilePath()
|
||||
|
||||
db, err := bolt.Open(databasePath, 0600, &bolt.Options{
|
||||
Timeout: 1 * time.Second,
|
||||
InitialMmapSize: connection.InitialMmapSize,
|
||||
FreelistType: bolt.FreelistMapType,
|
||||
NoFreelistSync: true,
|
||||
})
|
||||
db, err := bolt.Open(databasePath, 0600, connection.boltOptions(connection.Compact))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -149,6 +146,24 @@ func (connection *DbConnection) Open() error {
|
||||
db.MaxBatchDelay = connection.MaxBatchDelay
|
||||
connection.DB = db
|
||||
|
||||
if connection.Compact {
|
||||
log.Info().Msg("compacting database")
|
||||
if err := connection.compact(); err != nil {
|
||||
log.Error().Err(err).Msg("failed to compact database")
|
||||
|
||||
// Close the read-only database and re-open in read-write mode
|
||||
if err := connection.Close(); err != nil {
|
||||
log.Warn().Err(err).Msg("failure to close the database after failed compaction")
|
||||
}
|
||||
|
||||
connection.Compact = false
|
||||
|
||||
return connection.Open()
|
||||
} else {
|
||||
log.Info().Msg("database compaction completed")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -414,3 +429,48 @@ func (connection *DbConnection) RestoreMetadata(s map[string]any) error {
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// compact attempts to compact the database and replace it iff it succeeds
|
||||
func (connection *DbConnection) compact() (err error) {
|
||||
compactedPath := connection.GetDatabaseFilePath() + compactedSuffix
|
||||
|
||||
if err := os.Remove(compactedPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return fmt.Errorf("failure to remove an existing compacted database: %w", err)
|
||||
}
|
||||
|
||||
compactedDB, err := bolt.Open(compactedPath, 0o600, connection.boltOptions(false))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failure to create the compacted database: %w", err)
|
||||
}
|
||||
|
||||
compactedDB.MaxBatchSize = connection.MaxBatchSize
|
||||
compactedDB.MaxBatchDelay = connection.MaxBatchDelay
|
||||
|
||||
if err := bolt.Compact(compactedDB, connection.DB, txMaxSize); err != nil {
|
||||
return fmt.Errorf("failure to compact the database: %w",
|
||||
errors.Join(err, compactedDB.Close(), os.Remove(compactedPath)))
|
||||
}
|
||||
|
||||
if err := os.Rename(compactedPath, connection.GetDatabaseFilePath()); err != nil {
|
||||
return fmt.Errorf("failure to move the compacted database: %w",
|
||||
errors.Join(err, compactedDB.Close(), os.Remove(compactedPath)))
|
||||
}
|
||||
|
||||
if err := connection.Close(); err != nil {
|
||||
log.Warn().Err(err).Msg("failure to close the database after compaction")
|
||||
}
|
||||
|
||||
connection.DB = compactedDB
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (connection *DbConnection) boltOptions(readOnly bool) *bolt.Options {
|
||||
return &bolt.Options{
|
||||
Timeout: 1 * time.Second,
|
||||
InitialMmapSize: connection.InitialMmapSize,
|
||||
FreelistType: bolt.FreelistMapType,
|
||||
NoFreelistSync: true,
|
||||
ReadOnly: readOnly,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,11 @@ import (
|
||||
"path"
|
||||
"testing"
|
||||
|
||||
"github.com/portainer/portainer/api/filesystem"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
func Test_NeedsEncryptionMigration(t *testing.T) {
|
||||
@@ -119,3 +123,59 @@ func Test_NeedsEncryptionMigration(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBCompaction(t *testing.T) {
|
||||
db := &DbConnection{Path: t.TempDir()}
|
||||
|
||||
err := db.Open()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Update(func(tx *bbolt.Tx) error {
|
||||
b, err := tx.CreateBucketIfNotExists([]byte("testbucket"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
b.Put([]byte("key"), []byte("value"))
|
||||
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Reopen the DB to trigger compaction
|
||||
db.Compact = true
|
||||
err = db.Open()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that the data is still there
|
||||
err = db.View(func(tx *bbolt.Tx) error {
|
||||
b := tx.Bucket([]byte("testbucket"))
|
||||
if b == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
val := b.Get([]byte("key"))
|
||||
require.Equal(t, []byte("value"), val)
|
||||
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Failures
|
||||
compactedPath := db.GetDatabaseFilePath() + compactedSuffix
|
||||
err = os.Mkdir(compactedPath, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
f, err := os.Create(filesystem.JoinPaths(compactedPath, "somefile"))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, f.Close())
|
||||
|
||||
err = db.Open()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -8,11 +8,12 @@ import (
|
||||
)
|
||||
|
||||
// NewDatabase should use config options to return a connection to the requested database
|
||||
func NewDatabase(storeType, storePath string, encryptionKey []byte) (connection portainer.Connection, err error) {
|
||||
func NewDatabase(storeType, storePath string, encryptionKey []byte, compact bool) (connection portainer.Connection, err error) {
|
||||
if storeType == "boltdb" {
|
||||
return &boltdb.DbConnection{
|
||||
Path: storePath,
|
||||
EncryptionKey: encryptionKey,
|
||||
Compact: compact,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -28,13 +28,12 @@ func NewService(connection portainer.Connection) (*Service, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateCustomTemplate uses the existing id and saves it.
|
||||
// TODO: where does the ID come from, and is it safe?
|
||||
func (service *Service) Create(customTemplate *portainer.CustomTemplate) error {
|
||||
return service.Connection.CreateObjectWithId(BucketName, int(customTemplate.ID), customTemplate)
|
||||
}
|
||||
|
||||
// GetNextIdentifier returns the next identifier for a custom template.
|
||||
func (service *Service) GetNextIdentifier() int {
|
||||
return service.Connection.GetNextIdentifier(BucketName)
|
||||
}
|
||||
|
||||
func (service *Service) Create(customTemplate *portainer.CustomTemplate) error {
|
||||
return service.Connection.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return service.Tx(tx).Create(customTemplate)
|
||||
})
|
||||
}
|
||||
|
||||
19
api/dataservices/customtemplate/customtemplate_test.go
Normal file
19
api/dataservices/customtemplate/customtemplate_test.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package customtemplate_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCustomTemplateCreate(t *testing.T) {
|
||||
_, ds := datastore.MustNewTestStore(t, true, false)
|
||||
require.NotNil(t, ds)
|
||||
|
||||
require.NoError(t, ds.CustomTemplate().Create(&portainer.CustomTemplate{ID: 1}))
|
||||
e, err := ds.CustomTemplate().Read(1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, portainer.CustomTemplateID(1), e.ID)
|
||||
}
|
||||
31
api/dataservices/customtemplate/tx.go
Normal file
31
api/dataservices/customtemplate/tx.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package customtemplate
|
||||
|
||||
import (
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
)
|
||||
|
||||
// Service represents a service for managing custom template data.
|
||||
type ServiceTx struct {
|
||||
dataservices.BaseDataServiceTx[portainer.CustomTemplate, portainer.CustomTemplateID]
|
||||
}
|
||||
|
||||
func (service *Service) Tx(tx portainer.Transaction) ServiceTx {
|
||||
return ServiceTx{
|
||||
BaseDataServiceTx: dataservices.BaseDataServiceTx[portainer.CustomTemplate, portainer.CustomTemplateID]{
|
||||
Bucket: BucketName,
|
||||
Connection: service.Connection,
|
||||
Tx: tx,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (service ServiceTx) GetNextIdentifier() int {
|
||||
return service.Tx.GetNextIdentifier(BucketName)
|
||||
}
|
||||
|
||||
// CreateCustomTemplate uses the existing id and saves it.
|
||||
// TODO: where does the ID come from, and is it safe?
|
||||
func (service ServiceTx) Create(customTemplate *portainer.CustomTemplate) error {
|
||||
return service.Tx.CreateObjectWithId(BucketName, int(customTemplate.ID), customTemplate)
|
||||
}
|
||||
28
api/dataservices/customtemplate/tx_test.go
Normal file
28
api/dataservices/customtemplate/tx_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package customtemplate_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCustomTemplateCreateTx(t *testing.T) {
|
||||
_, ds := datastore.MustNewTestStore(t, true, false)
|
||||
require.NotNil(t, ds)
|
||||
|
||||
require.NoError(t, ds.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||
return tx.CustomTemplate().Create(&portainer.CustomTemplate{ID: 1})
|
||||
}))
|
||||
|
||||
var template *portainer.CustomTemplate
|
||||
require.NoError(t, ds.ViewTx(func(tx dataservices.DataStoreTx) error {
|
||||
var err error
|
||||
template, err = tx.CustomTemplate().Read(1)
|
||||
return err
|
||||
}))
|
||||
|
||||
require.Equal(t, portainer.CustomTemplateID(1), template.ID)
|
||||
}
|
||||
@@ -91,9 +91,9 @@ func (service *Service) UpdateEndpointRelation(endpointID portainer.EndpointID,
|
||||
})
|
||||
}
|
||||
|
||||
func (service *Service) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error {
|
||||
func (service *Service) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStack *portainer.EdgeStack) error {
|
||||
return service.connection.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return service.Tx(tx).AddEndpointRelationsForEdgeStack(endpointIDs, edgeStackID)
|
||||
return service.Tx(tx).AddEndpointRelationsForEdgeStack(endpointIDs, edgeStack)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/database/boltdb"
|
||||
"github.com/portainer/portainer/api/dataservices/edgestack"
|
||||
"github.com/portainer/portainer/api/internal/edge/cache"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -102,3 +103,38 @@ func TestUpdateRelation(t *testing.T) {
|
||||
require.Equal(t, 0, edgeStacks[edgeStackID1].NumDeployments)
|
||||
require.Equal(t, 0, edgeStacks[edgeStackID2].NumDeployments)
|
||||
}
|
||||
|
||||
func TestAddEndpointRelationsForEdgeStack(t *testing.T) {
|
||||
var conn portainer.Connection = &boltdb.DbConnection{Path: t.TempDir()}
|
||||
err := conn.Open()
|
||||
require.NoError(t, err)
|
||||
|
||||
defer conn.Close()
|
||||
|
||||
service, err := NewService(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
edgeStackService, err := edgestack.NewService(conn, func(t portainer.Transaction, esi portainer.EdgeStackID) {})
|
||||
require.NoError(t, err)
|
||||
|
||||
service.RegisterUpdateStackFunction(edgeStackService.UpdateEdgeStackFuncTx)
|
||||
require.NoError(t, edgeStackService.Create(1, &portainer.EdgeStack{}))
|
||||
require.NoError(t, service.Create(&portainer.EndpointRelation{EndpointID: 1, EdgeStacks: map[portainer.EdgeStackID]bool{}}))
|
||||
require.NoError(t, service.AddEndpointRelationsForEdgeStack([]portainer.EndpointID{1}, &portainer.EdgeStack{ID: 1}))
|
||||
}
|
||||
|
||||
func TestEndpointRelations(t *testing.T) {
|
||||
var conn portainer.Connection = &boltdb.DbConnection{Path: t.TempDir()}
|
||||
err := conn.Open()
|
||||
require.NoError(t, err)
|
||||
|
||||
defer conn.Close()
|
||||
|
||||
service, err := NewService(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, service.Create(&portainer.EndpointRelation{EndpointID: 1}))
|
||||
rels, err := service.EndpointRelations()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(rels))
|
||||
}
|
||||
|
||||
@@ -76,14 +76,14 @@ func (service ServiceTx) UpdateEndpointRelation(endpointID portainer.EndpointID,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (service ServiceTx) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error {
|
||||
func (service ServiceTx) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStack *portainer.EdgeStack) error {
|
||||
for _, endpointID := range endpointIDs {
|
||||
rel, err := service.EndpointRelation(endpointID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rel.EdgeStacks[edgeStackID] = true
|
||||
rel.EdgeStacks[edgeStack.ID] = true
|
||||
|
||||
identifier := service.service.connection.ConvertToKey(int(endpointID))
|
||||
err = service.tx.UpdateObject(BucketName, identifier, rel)
|
||||
@@ -97,8 +97,12 @@ func (service ServiceTx) AddEndpointRelationsForEdgeStack(endpointIDs []portaine
|
||||
service.service.endpointRelationsCache = nil
|
||||
service.service.mu.Unlock()
|
||||
|
||||
if err := service.service.updateStackFnTx(service.tx, edgeStackID, func(edgeStack *portainer.EdgeStack) {
|
||||
edgeStack.NumDeployments += len(endpointIDs)
|
||||
if err := service.service.updateStackFnTx(service.tx, edgeStack.ID, func(es *portainer.EdgeStack) {
|
||||
es.NumDeployments += len(endpointIDs)
|
||||
|
||||
// sync changes in `edgeStack` in case it is re-persisted after `AddEndpointRelationsForEdgeStack` call
|
||||
// to avoid overriding with the previous values
|
||||
edgeStack.NumDeployments = es.NumDeployments
|
||||
}); err != nil {
|
||||
log.Error().Err(err).Msg("could not update the number of deployments")
|
||||
}
|
||||
|
||||
@@ -126,7 +126,7 @@ type (
|
||||
EndpointRelation(EndpointID portainer.EndpointID) (*portainer.EndpointRelation, error)
|
||||
Create(endpointRelation *portainer.EndpointRelation) error
|
||||
UpdateEndpointRelation(EndpointID portainer.EndpointID, endpointRelation *portainer.EndpointRelation) error
|
||||
AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error
|
||||
AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStack *portainer.EdgeStack) error
|
||||
RemoveEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error
|
||||
DeleteEndpointRelation(EndpointID portainer.EndpointID) error
|
||||
BucketName() string
|
||||
|
||||
@@ -1,13 +1,8 @@
|
||||
package pendingactions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
const BucketName = "pending_actions"
|
||||
@@ -16,10 +11,6 @@ type Service struct {
|
||||
dataservices.BaseDataService[portainer.PendingAction, portainer.PendingActionID]
|
||||
}
|
||||
|
||||
type ServiceTx struct {
|
||||
dataservices.BaseDataServiceTx[portainer.PendingAction, portainer.PendingActionID]
|
||||
}
|
||||
|
||||
func NewService(connection portainer.Connection) (*Service, error) {
|
||||
err := connection.SetServiceName(BucketName)
|
||||
if err != nil {
|
||||
@@ -34,6 +25,11 @@ func NewService(connection portainer.Connection) (*Service, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetNextIdentifier returns the next identifier for a custom template.
|
||||
func (service *Service) GetNextIdentifier() int {
|
||||
return service.Connection.GetNextIdentifier(BucketName)
|
||||
}
|
||||
|
||||
func (s Service) Create(config *portainer.PendingAction) error {
|
||||
return s.Connection.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return s.Tx(tx).Create(config)
|
||||
@@ -61,43 +57,3 @@ func (service *Service) Tx(tx portainer.Transaction) ServiceTx {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s ServiceTx) Create(config *portainer.PendingAction) error {
|
||||
return s.Tx.CreateObject(BucketName, func(id uint64) (int, any) {
|
||||
config.ID = portainer.PendingActionID(id)
|
||||
config.CreatedAt = time.Now().Unix()
|
||||
|
||||
return int(config.ID), config
|
||||
})
|
||||
}
|
||||
|
||||
func (s ServiceTx) Update(ID portainer.PendingActionID, config *portainer.PendingAction) error {
|
||||
return s.BaseDataServiceTx.Update(ID, config)
|
||||
}
|
||||
|
||||
func (s ServiceTx) DeleteByEndpointID(ID portainer.EndpointID) error {
|
||||
log.Debug().Int("endpointId", int(ID)).Msg("deleting pending actions for endpoint")
|
||||
pendingActions, err := s.ReadAll()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to retrieve pending-actions for endpoint (%d): %w", ID, err)
|
||||
}
|
||||
|
||||
for _, pendingAction := range pendingActions {
|
||||
if pendingAction.EndpointID == ID {
|
||||
if err := s.Delete(pendingAction.ID); err != nil {
|
||||
log.Debug().Int("endpointId", int(ID)).Msgf("failed to delete pending action: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetNextIdentifier returns the next identifier for a custom template.
|
||||
func (service ServiceTx) GetNextIdentifier() int {
|
||||
return service.Tx.GetNextIdentifier(BucketName)
|
||||
}
|
||||
|
||||
// GetNextIdentifier returns the next identifier for a custom template.
|
||||
func (service *Service) GetNextIdentifier() int {
|
||||
return service.Connection.GetNextIdentifier(BucketName)
|
||||
}
|
||||
|
||||
49
api/dataservices/pendingactions/tx.go
Normal file
49
api/dataservices/pendingactions/tx.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package pendingactions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type ServiceTx struct {
|
||||
dataservices.BaseDataServiceTx[portainer.PendingAction, portainer.PendingActionID]
|
||||
}
|
||||
|
||||
func (s ServiceTx) Create(config *portainer.PendingAction) error {
|
||||
return s.Tx.CreateObject(BucketName, func(id uint64) (int, any) {
|
||||
config.ID = portainer.PendingActionID(id)
|
||||
config.CreatedAt = time.Now().Unix()
|
||||
|
||||
return int(config.ID), config
|
||||
})
|
||||
}
|
||||
|
||||
func (s ServiceTx) Update(ID portainer.PendingActionID, config *portainer.PendingAction) error {
|
||||
return s.BaseDataServiceTx.Update(ID, config)
|
||||
}
|
||||
|
||||
func (s ServiceTx) DeleteByEndpointID(ID portainer.EndpointID) error {
|
||||
log.Debug().Int("endpointId", int(ID)).Msg("deleting pending actions for endpoint")
|
||||
pendingActions, err := s.ReadAll()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to retrieve pending-actions for endpoint (%d): %w", ID, err)
|
||||
}
|
||||
|
||||
for _, pendingAction := range pendingActions {
|
||||
if pendingAction.EndpointID == ID {
|
||||
if err := s.Delete(pendingAction.ID); err != nil {
|
||||
log.Debug().Int("endpointId", int(ID)).Msgf("failed to delete pending action: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetNextIdentifier returns the next identifier for a custom template.
|
||||
func (service ServiceTx) GetNextIdentifier() int {
|
||||
return service.Tx.GetNextIdentifier(BucketName)
|
||||
}
|
||||
@@ -11,8 +11,10 @@ func (m *Migrator) migrateEdgeGroupEndpointsToRoars_2_33_0() error {
|
||||
}
|
||||
|
||||
for _, eg := range egs {
|
||||
eg.EndpointIDs = roar.FromSlice(eg.Endpoints)
|
||||
eg.Endpoints = nil
|
||||
if eg.EndpointIDs.Len() == 0 {
|
||||
eg.EndpointIDs = roar.FromSlice(eg.Endpoints)
|
||||
eg.Endpoints = nil
|
||||
}
|
||||
|
||||
if err := m.edgeGroupService.Update(eg.ID, &eg); err != nil {
|
||||
return err
|
||||
|
||||
55
api/datastore/migrator/migrate_2_33_test.go
Normal file
55
api/datastore/migrator/migrate_2_33_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package migrator
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/database/boltdb"
|
||||
"github.com/portainer/portainer/api/dataservices/edgegroup"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMigrateEdgeGroupEndpointsToRoars_2_33_0Idempotency(t *testing.T) {
|
||||
var conn portainer.Connection = &boltdb.DbConnection{Path: t.TempDir()}
|
||||
err := conn.Open()
|
||||
require.NoError(t, err)
|
||||
|
||||
defer conn.Close()
|
||||
|
||||
edgeGroupService, err := edgegroup.NewService(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
edgeGroup := &portainer.EdgeGroup{
|
||||
ID: 1,
|
||||
Name: "test-edge-group",
|
||||
Endpoints: []portainer.EndpointID{1, 2, 3},
|
||||
}
|
||||
|
||||
err = conn.CreateObjectWithId(edgegroup.BucketName, int(edgeGroup.ID), edgeGroup)
|
||||
require.NoError(t, err)
|
||||
|
||||
m := NewMigrator(&MigratorParameters{EdgeGroupService: edgeGroupService})
|
||||
|
||||
// Run migration once
|
||||
|
||||
err = m.migrateEdgeGroupEndpointsToRoars_2_33_0()
|
||||
require.NoError(t, err)
|
||||
|
||||
migratedEdgeGroup, err := edgeGroupService.Read(edgeGroup.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, migratedEdgeGroup.Endpoints, 0)
|
||||
require.Equal(t, len(edgeGroup.Endpoints), migratedEdgeGroup.EndpointIDs.Len())
|
||||
|
||||
// Run migration again to ensure the results didn't change
|
||||
|
||||
err = m.migrateEdgeGroupEndpointsToRoars_2_33_0()
|
||||
require.NoError(t, err)
|
||||
|
||||
migratedEdgeGroup, err = edgeGroupService.Read(edgeGroup.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, migratedEdgeGroup.Endpoints, 0)
|
||||
require.Equal(t, len(edgeGroup.Endpoints), migratedEdgeGroup.EndpointIDs.Len())
|
||||
}
|
||||
@@ -256,10 +256,7 @@ func (m *Migrator) initMigrations() {
|
||||
|
||||
m.addMigrations("2.32.0", m.addEndpointRelationForEdgeAgents_2_32_0)
|
||||
|
||||
m.addMigrations("2.33.0-rc1", m.migrateEdgeGroupEndpointsToRoars_2_33_0)
|
||||
|
||||
//m.addMigrations("2.33.0", m.migrateEdgeGroupEndpointsToRoars_2_33_0)
|
||||
// when we release 2.33.0 it will also run the rc-1 migration function
|
||||
m.addMigrations("2.33.1", m.migrateEdgeGroupEndpointsToRoars_2_33_0)
|
||||
|
||||
// Add new migrations above...
|
||||
// One function per migration, each versions migration funcs in the same file.
|
||||
|
||||
@@ -2,6 +2,7 @@ package postinit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/docker/docker/client"
|
||||
@@ -83,17 +84,27 @@ func (postInitMigrator *PostInitMigrator) PostInitMigrate() error {
|
||||
|
||||
// try to create a post init migration pending action. If it already exists, do nothing
|
||||
// this function exists for readability, not reusability
|
||||
// TODO: This should be moved into pending actions as part of the pending action migration
|
||||
func (postInitMigrator *PostInitMigrator) createPostInitMigrationPendingAction(environmentID portainer.EndpointID) error {
|
||||
// If there are no pending actions for the given endpoint, create one
|
||||
err := postInitMigrator.dataStore.PendingActions().Create(&portainer.PendingAction{
|
||||
action := portainer.PendingAction{
|
||||
EndpointID: environmentID,
|
||||
Action: actions.PostInitMigrateEnvironment,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Error creating pending action for environment %d", environmentID)
|
||||
}
|
||||
return nil
|
||||
pendingActions, err := postInitMigrator.dataStore.PendingActions().ReadAll()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to retrieve pending actions: %w", err)
|
||||
}
|
||||
|
||||
for _, dba := range pendingActions {
|
||||
if dba.EndpointID == action.EndpointID && dba.Action == action.Action {
|
||||
log.Debug().
|
||||
Str("action", action.Action).
|
||||
Int("endpoint_id", int(action.EndpointID)).
|
||||
Msg("pending action already exists for environment, skipping...")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return postInitMigrator.dataStore.PendingActions().Create(&action)
|
||||
}
|
||||
|
||||
// MigrateEnvironment runs migrations on a single environment
|
||||
|
||||
@@ -8,10 +8,12 @@ import (
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/portainer/portainer/api/pendingactions/actions"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/docker/docker/client"
|
||||
"github.com/segmentio/encoding/json"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -73,3 +75,96 @@ func TestMigrateGPUs(t *testing.T) {
|
||||
require.False(t, migratedEndpoint.PostInitMigrations.MigrateGPUs)
|
||||
require.True(t, migratedEndpoint.EnableGPUManagement)
|
||||
}
|
||||
|
||||
func TestPostInitMigrate_PendingActionsCreated(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
existingPendingActions []*portainer.PendingAction
|
||||
expectedPendingActions int
|
||||
expectedAction string
|
||||
}{
|
||||
{
|
||||
name: "when existing non-matching action exists, should add migration action",
|
||||
existingPendingActions: []*portainer.PendingAction{
|
||||
{
|
||||
EndpointID: 7,
|
||||
Action: "some-other-action",
|
||||
},
|
||||
},
|
||||
expectedPendingActions: 2,
|
||||
expectedAction: actions.PostInitMigrateEnvironment,
|
||||
},
|
||||
{
|
||||
name: "when matching action exists, should not add duplicate",
|
||||
existingPendingActions: []*portainer.PendingAction{
|
||||
{
|
||||
EndpointID: 7,
|
||||
Action: actions.PostInitMigrateEnvironment,
|
||||
},
|
||||
},
|
||||
expectedPendingActions: 1,
|
||||
expectedAction: actions.PostInitMigrateEnvironment,
|
||||
},
|
||||
{
|
||||
name: "when no actions exist, should add migration action",
|
||||
existingPendingActions: []*portainer.PendingAction{},
|
||||
expectedPendingActions: 1,
|
||||
expectedAction: actions.PostInitMigrateEnvironment,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
is := assert.New(t)
|
||||
_, store := datastore.MustNewTestStore(t, true, true)
|
||||
|
||||
// Create test endpoint
|
||||
endpoint := &portainer.Endpoint{
|
||||
ID: 7,
|
||||
UserTrusted: true,
|
||||
Type: portainer.EdgeAgentOnDockerEnvironment,
|
||||
Edge: portainer.EnvironmentEdgeSettings{
|
||||
AsyncMode: false,
|
||||
},
|
||||
EdgeID: "edgeID",
|
||||
}
|
||||
err := store.Endpoint().Create(endpoint)
|
||||
is.NoError(err, "error creating endpoint")
|
||||
|
||||
// Create any existing pending actions
|
||||
for _, action := range tt.existingPendingActions {
|
||||
err = store.PendingActions().Create(action)
|
||||
is.NoError(err, "error creating pending action")
|
||||
}
|
||||
|
||||
migrator := NewPostInitMigrator(
|
||||
nil, // kubeFactory not needed for this test
|
||||
nil, // dockerFactory not needed for this test
|
||||
store,
|
||||
"", // assetsPath not needed for this test
|
||||
nil, // kubernetesDeployer not needed for this test
|
||||
)
|
||||
|
||||
err = migrator.PostInitMigrate()
|
||||
is.NoError(err, "PostInitMigrate should not return error")
|
||||
|
||||
// Verify the results
|
||||
pendingActions, err := store.PendingActions().ReadAll()
|
||||
is.NoError(err, "error reading pending actions")
|
||||
is.Len(pendingActions, tt.expectedPendingActions, "unexpected number of pending actions")
|
||||
|
||||
// If we expect any actions, verify at least one has the expected action type
|
||||
if tt.expectedPendingActions > 0 {
|
||||
hasExpectedAction := false
|
||||
for _, action := range pendingActions {
|
||||
if action.Action == tt.expectedAction {
|
||||
hasExpectedAction = true
|
||||
is.Equal(endpoint.ID, action.EndpointID, "action should reference correct endpoint")
|
||||
break
|
||||
}
|
||||
}
|
||||
is.True(hasExpectedAction, "should have found action of expected type")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,7 +14,9 @@ func (tx *StoreTx) IsErrObjectNotFound(err error) bool {
|
||||
return tx.store.IsErrObjectNotFound(err)
|
||||
}
|
||||
|
||||
func (tx *StoreTx) CustomTemplate() dataservices.CustomTemplateService { return nil }
|
||||
func (tx *StoreTx) CustomTemplate() dataservices.CustomTemplateService {
|
||||
return tx.store.CustomTemplateService.Tx(tx.tx)
|
||||
}
|
||||
|
||||
func (tx *StoreTx) PendingActions() dataservices.PendingActionsService {
|
||||
return tx.store.PendingActionsService.Tx(tx.tx)
|
||||
|
||||
@@ -615,7 +615,7 @@
|
||||
"RequiredPasswordLength": 12
|
||||
},
|
||||
"KubeconfigExpiry": "0",
|
||||
"KubectlShellImage": "portainer/kubectl-shell:2.33.0-rc1",
|
||||
"KubectlShellImage": "portainer/kubectl-shell:2.33.4",
|
||||
"LDAPSettings": {
|
||||
"AnonymousMode": true,
|
||||
"AutoCreateUsers": true,
|
||||
@@ -944,7 +944,7 @@
|
||||
}
|
||||
],
|
||||
"version": {
|
||||
"VERSION": "{\"SchemaVersion\":\"2.33.0-rc1\",\"MigratorCount\":1,\"Edition\":1,\"InstanceID\":\"463d5c47-0ea5-4aca-85b1-405ceefee254\"}"
|
||||
"VERSION": "{\"SchemaVersion\":\"2.33.4\",\"MigratorCount\":0,\"Edition\":1,\"InstanceID\":\"463d5c47-0ea5-4aca-85b1-405ceefee254\"}"
|
||||
},
|
||||
"webhooks": null
|
||||
}
|
||||
@@ -44,7 +44,7 @@ func NewTestStore(t testing.TB, init, secure bool) (bool, *Store, func(), error)
|
||||
secretKey = nil
|
||||
}
|
||||
|
||||
connection, err := database.NewDatabase("boltdb", storePath, secretKey)
|
||||
connection, err := database.NewDatabase("boltdb", storePath, secretKey, false)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
@@ -4,10 +4,14 @@ import (
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHttpClient(t *testing.T) {
|
||||
fips.InitFIPS(false)
|
||||
|
||||
// Valid TLS configuration
|
||||
endpoint := &portainer.Endpoint{}
|
||||
endpoint.TLSConfig = portainer.TLSConfiguration{TLS: true}
|
||||
|
||||
@@ -49,6 +49,11 @@ type (
|
||||
|
||||
// Is relative path supported
|
||||
SupportRelativePath bool
|
||||
// AlwaysCloneGitRepoForRelativePath is a flag indicating if the agent must always clone the git repository for relative path.
|
||||
// This field is only valid when SupportRelativePath is true.
|
||||
// Used only for EE
|
||||
AlwaysCloneGitRepoForRelativePath bool
|
||||
|
||||
// Mount point for relative path
|
||||
FilesystemPath string
|
||||
// Used only for EE
|
||||
|
||||
@@ -848,7 +848,7 @@ func defaultMTLSCertPathUnderFileStore() (string, string, string) {
|
||||
return caCertPath, certPath, keyPath
|
||||
}
|
||||
|
||||
// GetDefaultChiselPrivateKeyPath returns the chisle private key path
|
||||
// GetDefaultChiselPrivateKeyPath returns the chisel private key path
|
||||
func (service *Service) GetDefaultChiselPrivateKeyPath() string {
|
||||
privateKeyPath := defaultChiselPrivateKeyPathUnderFileStore()
|
||||
return service.wrapFileStore(privateKeyPath)
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"testing"
|
||||
|
||||
gittypes "github.com/portainer/portainer/api/git/types"
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@@ -234,6 +235,8 @@ func Test_isAzureUrl(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_azureDownloader_downloadZipFromAzureDevOps(t *testing.T) {
|
||||
fips.InitFIPS(false)
|
||||
|
||||
type args struct {
|
||||
options baseOption
|
||||
}
|
||||
@@ -308,6 +311,8 @@ func Test_azureDownloader_downloadZipFromAzureDevOps(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_azureDownloader_latestCommitID(t *testing.T) {
|
||||
fips.InitFIPS(false)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
response := `{
|
||||
"count": 1,
|
||||
|
||||
@@ -21,10 +21,14 @@ func ValidateAutoUpdateSettings(autoUpdate *portainer.AutoUpdateSettings) error
|
||||
return httperrors.NewInvalidPayloadError("invalid Webhook format")
|
||||
}
|
||||
|
||||
if autoUpdate.Interval != "" {
|
||||
if _, err := time.ParseDuration(autoUpdate.Interval); err != nil {
|
||||
return httperrors.NewInvalidPayloadError("invalid Interval format")
|
||||
}
|
||||
if autoUpdate.Interval == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if d, err := time.ParseDuration(autoUpdate.Interval); err != nil {
|
||||
return httperrors.NewInvalidPayloadError("invalid Interval format")
|
||||
} else if d < time.Minute {
|
||||
return httperrors.NewInvalidPayloadError("interval must be at least 1 minute")
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -23,6 +23,16 @@ func Test_ValidateAutoUpdate(t *testing.T) {
|
||||
value: &portainer.AutoUpdateSettings{Interval: "1dd2hh3mm"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "short interval value",
|
||||
value: &portainer.AutoUpdateSettings{Interval: "1s"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "valid webhook without interval",
|
||||
value: &portainer.AutoUpdateSettings{Webhook: "8dce8c2f-9ca1-482b-ad20-271e86536ada"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid auto update",
|
||||
value: &portainer.AutoUpdateSettings{
|
||||
|
||||
@@ -4,10 +4,14 @@ import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewService(t *testing.T) {
|
||||
fips.InitFIPS(false)
|
||||
|
||||
service := NewService(true)
|
||||
require.NotNil(t, service)
|
||||
require.True(t, service.httpsClient.Transport.(*http.Transport).TLSClientConfig.InsecureSkipVerify) //nolint:forbidigo
|
||||
|
||||
@@ -6,11 +6,14 @@ import (
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExecutePingOperationFailure(t *testing.T) {
|
||||
fips.InitFIPS(false)
|
||||
|
||||
host := "http://localhost:1"
|
||||
config := portainer.TLSConfiguration{
|
||||
TLS: true,
|
||||
|
||||
@@ -26,11 +26,10 @@ func (handler *Handler) logout(w http.ResponseWriter, r *http.Request) *httperro
|
||||
handler.KubernetesTokenCacheManager.RemoveUserFromCache(tokenData.ID)
|
||||
handler.KubernetesClientFactory.ClearUserClientCache(strconv.Itoa(int(tokenData.ID)))
|
||||
logoutcontext.Cancel(tokenData.Token)
|
||||
handler.bouncer.RevokeJWT(tokenData.Token)
|
||||
}
|
||||
|
||||
security.RemoveAuthCookie(w)
|
||||
|
||||
handler.bouncer.RevokeJWT(tokenData.Token)
|
||||
|
||||
return response.Empty(w)
|
||||
}
|
||||
|
||||
55
api/http/handler/auth/logout_test.go
Normal file
55
api/http/handler/auth/logout_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/http/proxy/factory/kubernetes"
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
"github.com/portainer/portainer/api/internal/testhelpers"
|
||||
"github.com/portainer/portainer/api/kubernetes/cli"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockBouncer struct {
|
||||
security.BouncerService
|
||||
}
|
||||
|
||||
func NewMockBouncer() *mockBouncer {
|
||||
return &mockBouncer{BouncerService: testhelpers.NewTestRequestBouncer()}
|
||||
}
|
||||
|
||||
func (*mockBouncer) CookieAuthLookup(r *http.Request) (*portainer.TokenData, error) {
|
||||
return &portainer.TokenData{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Token: "valid-token",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestLogout(t *testing.T) {
|
||||
h := NewHandler(NewMockBouncer(), nil, nil, nil)
|
||||
h.KubernetesTokenCacheManager = kubernetes.NewTokenCacheManager()
|
||||
k, err := cli.NewClientFactory(nil, nil, nil, "", "", "")
|
||||
require.NoError(t, err)
|
||||
h.KubernetesClientFactory = k
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/auth/logout", nil)
|
||||
|
||||
h.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusNoContent, rr.Code)
|
||||
}
|
||||
|
||||
func TestLogoutNoPanic(t *testing.T) {
|
||||
h := NewHandler(testhelpers.NewTestRequestBouncer(), nil, nil, nil)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/auth/logout", nil)
|
||||
|
||||
h.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusNoContent, rr.Code)
|
||||
}
|
||||
@@ -5,8 +5,11 @@ import (
|
||||
"strconv"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
httperrors "github.com/portainer/portainer/api/http/errors"
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
"github.com/portainer/portainer/api/internal/authorization"
|
||||
"github.com/portainer/portainer/api/slicesx"
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||
@@ -32,31 +35,45 @@ func (handler *Handler) customTemplateInspect(w http.ResponseWriter, r *http.Req
|
||||
return httperror.BadRequest("Invalid Custom template identifier route variable", err)
|
||||
}
|
||||
|
||||
customTemplate, err := handler.DataStore.CustomTemplate().Read(portainer.CustomTemplateID(customTemplateID))
|
||||
if handler.DataStore.IsErrObjectNotFound(err) {
|
||||
return httperror.NotFound("Unable to find a custom template with the specified identifier inside the database", err)
|
||||
} else if err != nil {
|
||||
return httperror.InternalServerError("Unable to find a custom template with the specified identifier inside the database", err)
|
||||
}
|
||||
var customTemplate *portainer.CustomTemplate
|
||||
err = handler.DataStore.ViewTx(func(tx dataservices.DataStoreTx) error {
|
||||
customTemplate, err = tx.CustomTemplate().Read(portainer.CustomTemplateID(customTemplateID))
|
||||
if handler.DataStore.IsErrObjectNotFound(err) {
|
||||
return httperror.NotFound("Unable to find a custom template with the specified identifier inside the database", err)
|
||||
} else if err != nil {
|
||||
return httperror.InternalServerError("Unable to find a custom template with the specified identifier inside the database", err)
|
||||
}
|
||||
|
||||
securityContext, err := security.RetrieveRestrictedRequestContext(r)
|
||||
if err != nil {
|
||||
return httperror.InternalServerError("Unable to retrieve user info from request context", err)
|
||||
}
|
||||
resourceControl, err := tx.ResourceControl().ResourceControlByResourceIDAndType(strconv.Itoa(customTemplateID), portainer.CustomTemplateResourceControl)
|
||||
if err != nil {
|
||||
return httperror.InternalServerError("Unable to retrieve a resource control associated to the custom template", err)
|
||||
}
|
||||
|
||||
resourceControl, err := handler.DataStore.ResourceControl().ResourceControlByResourceIDAndType(strconv.Itoa(customTemplateID), portainer.CustomTemplateResourceControl)
|
||||
if err != nil {
|
||||
return httperror.InternalServerError("Unable to retrieve a resource control associated to the custom template", err)
|
||||
}
|
||||
securityContext, err := security.RetrieveRestrictedRequestContext(r)
|
||||
if err != nil {
|
||||
return httperror.InternalServerError("Unable to retrieve user info from request context", err)
|
||||
}
|
||||
|
||||
canEdit := userCanEditTemplate(customTemplate, securityContext)
|
||||
hasAccess := false
|
||||
|
||||
if resourceControl != nil {
|
||||
customTemplate.ResourceControl = resourceControl
|
||||
|
||||
teamIDs := slicesx.Map(securityContext.UserMemberships, func(m portainer.TeamMembership) portainer.TeamID {
|
||||
return m.TeamID
|
||||
})
|
||||
|
||||
hasAccess = authorization.UserCanAccessResource(securityContext.UserID, teamIDs, resourceControl)
|
||||
|
||||
}
|
||||
|
||||
if canEdit || hasAccess {
|
||||
return nil
|
||||
}
|
||||
|
||||
access := userCanEditTemplate(customTemplate, securityContext)
|
||||
if !access {
|
||||
return httperror.Forbidden("Access denied to resource", httperrors.ErrResourceAccessDenied)
|
||||
}
|
||||
})
|
||||
|
||||
if resourceControl != nil {
|
||||
customTemplate.ResourceControl = resourceControl
|
||||
}
|
||||
|
||||
return response.JSON(w, customTemplate)
|
||||
return response.TxResponse(w, customTemplate, err)
|
||||
}
|
||||
|
||||
100
api/http/handler/customtemplates/customtemplate_inspect_test.go
Normal file
100
api/http/handler/customtemplates/customtemplate_inspect_test.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package customtemplates
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
"github.com/portainer/portainer/api/internal/testhelpers"
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/segmentio/encoding/json"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestInspectHandler(t *testing.T) {
|
||||
_, ds := datastore.MustNewTestStore(t, true, false)
|
||||
require.NotNil(t, ds)
|
||||
|
||||
require.NoError(t, ds.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||
require.NoError(t, tx.User().Create(&portainer.User{ID: 1, Username: "admin", Role: portainer.AdministratorRole}))
|
||||
require.NoError(t, tx.User().Create(&portainer.User{ID: 2, Username: "std2", Role: portainer.StandardUserRole}))
|
||||
require.NoError(t, tx.User().Create(&portainer.User{ID: 3, Username: "std3", Role: portainer.StandardUserRole}))
|
||||
require.NoError(t, tx.User().Create(&portainer.User{ID: 4, Username: "std4", Role: portainer.StandardUserRole}))
|
||||
require.NoError(t, tx.Endpoint().Create(&portainer.Endpoint{ID: 1,
|
||||
UserAccessPolicies: portainer.UserAccessPolicies{
|
||||
2: portainer.AccessPolicy{RoleID: 0},
|
||||
3: portainer.AccessPolicy{RoleID: 0},
|
||||
}}))
|
||||
require.NoError(t, tx.Team().Create(&portainer.Team{ID: 1}))
|
||||
require.NoError(t, tx.TeamMembership().Create(&portainer.TeamMembership{ID: 1, UserID: 3, TeamID: 1, Role: portainer.TeamMember}))
|
||||
|
||||
require.NoError(t, tx.CustomTemplate().Create(&portainer.CustomTemplate{ID: 1}))
|
||||
require.NoError(t, tx.CustomTemplate().Create(&portainer.CustomTemplate{ID: 2}))
|
||||
require.NoError(t, tx.ResourceControl().Create(&portainer.ResourceControl{ID: 1, ResourceID: "2", Type: portainer.CustomTemplateResourceControl,
|
||||
UserAccesses: []portainer.UserResourceAccess{{UserID: 2}},
|
||||
TeamAccesses: []portainer.TeamResourceAccess{{TeamID: 1}},
|
||||
}))
|
||||
return nil
|
||||
}))
|
||||
|
||||
handler := NewHandler(testhelpers.NewTestRequestBouncer(), ds, &TestFileService{}, nil)
|
||||
|
||||
test := func(templateID string, restrictedContext *security.RestrictedRequestContext) (*httptest.ResponseRecorder, *httperror.HandlerError) {
|
||||
r := httptest.NewRequest(http.MethodGet, "/custom_templates/"+templateID, nil)
|
||||
r = mux.SetURLVars(r, map[string]string{"id": templateID})
|
||||
ctx := security.StoreRestrictedRequestContext(r, restrictedContext)
|
||||
r = r.WithContext(ctx)
|
||||
rr := httptest.NewRecorder()
|
||||
return rr, handler.customTemplateInspect(rr, r)
|
||||
}
|
||||
|
||||
t.Run("unknown id should get not found error", func(t *testing.T) {
|
||||
_, r := test("0", &security.RestrictedRequestContext{UserID: 1})
|
||||
require.NotNil(t, r)
|
||||
require.Equal(t, http.StatusNotFound, r.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("admin should access adminonly template", func(t *testing.T) {
|
||||
rr, r := test("1", &security.RestrictedRequestContext{UserID: 1, IsAdmin: true})
|
||||
require.Nil(t, r)
|
||||
require.Equal(t, http.StatusOK, rr.Result().StatusCode)
|
||||
var template portainer.CustomTemplate
|
||||
require.NoError(t, json.NewDecoder(rr.Body).Decode(&template))
|
||||
require.Equal(t, portainer.CustomTemplateID(1), template.ID)
|
||||
})
|
||||
|
||||
t.Run("std should not access adminonly template", func(t *testing.T) {
|
||||
_, r := test("1", &security.RestrictedRequestContext{UserID: 2})
|
||||
require.NotNil(t, r)
|
||||
require.Equal(t, http.StatusForbidden, r.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("std should access template via direct user access", func(t *testing.T) {
|
||||
rr, r := test("2", &security.RestrictedRequestContext{UserID: 2})
|
||||
require.Nil(t, r)
|
||||
require.Equal(t, http.StatusOK, rr.Result().StatusCode)
|
||||
var template portainer.CustomTemplate
|
||||
require.NoError(t, json.NewDecoder(rr.Body).Decode(&template))
|
||||
require.Equal(t, portainer.CustomTemplateID(2), template.ID)
|
||||
})
|
||||
|
||||
t.Run("std should access template via team access", func(t *testing.T) {
|
||||
rr, r := test("2", &security.RestrictedRequestContext{UserID: 3, UserMemberships: []portainer.TeamMembership{{ID: 1, UserID: 3, TeamID: 1}}})
|
||||
require.Nil(t, r)
|
||||
require.Equal(t, http.StatusOK, rr.Result().StatusCode)
|
||||
var template portainer.CustomTemplate
|
||||
require.NoError(t, json.NewDecoder(rr.Body).Decode(&template))
|
||||
require.Equal(t, portainer.CustomTemplateID(2), template.ID)
|
||||
})
|
||||
|
||||
t.Run("std should not access template without access", func(t *testing.T) {
|
||||
_, r := test("2", &security.RestrictedRequestContext{UserID: 4})
|
||||
require.NotNil(t, r)
|
||||
require.Equal(t, http.StatusForbidden, r.StatusCode)
|
||||
})
|
||||
}
|
||||
@@ -99,7 +99,7 @@ func (handler *Handler) updateEdgeStack(tx dataservices.DataStoreTx, stackID por
|
||||
|
||||
groupsIds := stack.EdgeGroups
|
||||
if payload.EdgeGroups != nil {
|
||||
newRelated, _, err := handler.handleChangeEdgeGroups(tx, stack.ID, payload.EdgeGroups, relatedEndpointIds, relationConfig)
|
||||
newRelated, _, err := handler.handleChangeEdgeGroups(tx, stack, payload.EdgeGroups, relatedEndpointIds, relationConfig)
|
||||
if err != nil {
|
||||
return nil, httperror.InternalServerError("Unable to handle edge groups change", err)
|
||||
}
|
||||
@@ -136,7 +136,7 @@ func (handler *Handler) updateEdgeStack(tx dataservices.DataStoreTx, stackID por
|
||||
return stack, nil
|
||||
}
|
||||
|
||||
func (handler *Handler) handleChangeEdgeGroups(tx dataservices.DataStoreTx, edgeStackID portainer.EdgeStackID, newEdgeGroupsIDs []portainer.EdgeGroupID, oldRelatedEnvironmentIDs []portainer.EndpointID, relationConfig *edge.EndpointRelationsConfig) ([]portainer.EndpointID, set.Set[portainer.EndpointID], error) {
|
||||
func (handler *Handler) handleChangeEdgeGroups(tx dataservices.DataStoreTx, edgeStack *portainer.EdgeStack, newEdgeGroupsIDs []portainer.EdgeGroupID, oldRelatedEnvironmentIDs []portainer.EndpointID, relationConfig *edge.EndpointRelationsConfig) ([]portainer.EndpointID, set.Set[portainer.EndpointID], error) {
|
||||
newRelatedEnvironmentIDs, err := edge.EdgeStackRelatedEndpoints(newEdgeGroupsIDs, relationConfig.Endpoints, relationConfig.EndpointGroups, relationConfig.EdgeGroups)
|
||||
if err != nil {
|
||||
return nil, nil, errors.WithMessage(err, "Unable to retrieve edge stack related environments from database")
|
||||
@@ -149,13 +149,13 @@ func (handler *Handler) handleChangeEdgeGroups(tx dataservices.DataStoreTx, edge
|
||||
relatedEnvironmentsToRemove := oldRelatedEnvironmentsSet.Difference(newRelatedEnvironmentsSet)
|
||||
|
||||
if len(relatedEnvironmentsToRemove) > 0 {
|
||||
if err := tx.EndpointRelation().RemoveEndpointRelationsForEdgeStack(relatedEnvironmentsToRemove.Keys(), edgeStackID); err != nil {
|
||||
if err := tx.EndpointRelation().RemoveEndpointRelationsForEdgeStack(relatedEnvironmentsToRemove.Keys(), edgeStack.ID); err != nil {
|
||||
return nil, nil, errors.WithMessage(err, "Unable to remove edge stack relations from the database")
|
||||
}
|
||||
}
|
||||
|
||||
if len(relatedEnvironmentsToAdd) > 0 {
|
||||
if err := tx.EndpointRelation().AddEndpointRelationsForEdgeStack(relatedEnvironmentsToAdd.Keys(), edgeStackID); err != nil {
|
||||
if err := tx.EndpointRelation().AddEndpointRelationsForEdgeStack(relatedEnvironmentsToAdd.Keys(), edgeStack); err != nil {
|
||||
return nil, nil, errors.WithMessage(err, "Unable to add edge stack relations to the database")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -372,10 +372,16 @@ func (handler *Handler) createEdgeAgentEndpoint(tx dataservices.DataStoreTx, pay
|
||||
edgeKey := handler.ReverseTunnelService.GenerateEdgeKey(payload.URL, portainerHost, endpointID)
|
||||
|
||||
endpoint := &portainer.Endpoint{
|
||||
ID: portainer.EndpointID(endpointID),
|
||||
Name: payload.Name,
|
||||
URL: portainerHost,
|
||||
Type: portainer.EdgeAgentOnDockerEnvironment,
|
||||
ID: portainer.EndpointID(endpointID),
|
||||
Name: payload.Name,
|
||||
URL: portainerHost,
|
||||
Type: func() portainer.EndpointType {
|
||||
// an empty container engine means that the endpoint is a Kubernetes endpoint
|
||||
if payload.ContainerEngine == "" {
|
||||
return portainer.EdgeAgentOnKubernetesEnvironment
|
||||
}
|
||||
return portainer.EdgeAgentOnDockerEnvironment
|
||||
}(),
|
||||
ContainerEngine: payload.ContainerEngine,
|
||||
GroupID: portainer.EndpointGroupID(payload.GroupID),
|
||||
Gpus: payload.Gpus,
|
||||
|
||||
172
api/http/handler/endpoints/endpoint_create_test.go
Normal file
172
api/http/handler/endpoints/endpoint_create_test.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package endpoints
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/chisel"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/portainer/portainer/api/internal/testhelpers"
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// EE-only kubeconfig validation tests removed for CE
|
||||
|
||||
func TestSaveEndpointAndUpdateAuthorizations(t *testing.T) {
|
||||
_, store := datastore.MustNewTestStore(t, true, false)
|
||||
|
||||
endpointGroup := &portainer.EndpointGroup{
|
||||
ID: 1,
|
||||
Name: "test-endpoint-group",
|
||||
}
|
||||
|
||||
err := store.EndpointGroup().Create(endpointGroup)
|
||||
require.NoError(t, err)
|
||||
|
||||
h := &Handler{
|
||||
DataStore: store,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
endpointType portainer.EndpointType
|
||||
expectRelation bool
|
||||
}{
|
||||
{
|
||||
name: "create azure environment, expect no relation to be created",
|
||||
endpointType: portainer.AzureEnvironment,
|
||||
expectRelation: false,
|
||||
},
|
||||
{
|
||||
name: "create edge agent environment, expect relation to be created",
|
||||
endpointType: portainer.EdgeAgentOnDockerEnvironment,
|
||||
expectRelation: true,
|
||||
},
|
||||
{
|
||||
name: "create kubernetes environment, expect no relation to be created",
|
||||
endpointType: portainer.KubernetesLocalEnvironment,
|
||||
expectRelation: false,
|
||||
},
|
||||
{
|
||||
name: "create kubeconfig environment, expect no relation to be created",
|
||||
endpointType: portainer.AgentOnKubernetesEnvironment,
|
||||
expectRelation: false,
|
||||
},
|
||||
{
|
||||
name: "create agent docker environment, expect no relation to be created",
|
||||
endpointType: portainer.AgentOnDockerEnvironment,
|
||||
expectRelation: false,
|
||||
},
|
||||
{
|
||||
name: "create unsecured environment, expect no relation to be created",
|
||||
endpointType: portainer.DockerEnvironment,
|
||||
expectRelation: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
endpoint := &portainer.Endpoint{
|
||||
ID: portainer.EndpointID(store.Endpoint().GetNextIdentifier()),
|
||||
Type: testCase.endpointType,
|
||||
GroupID: portainer.EndpointGroupID(endpointGroup.ID),
|
||||
}
|
||||
|
||||
err := h.saveEndpointAndUpdateAuthorizations(store, endpoint)
|
||||
require.NoError(t, err)
|
||||
|
||||
relation, relationErr := store.EndpointRelation().EndpointRelation(endpoint.ID)
|
||||
if testCase.expectRelation {
|
||||
require.NoError(t, relationErr)
|
||||
require.NotNil(t, relation)
|
||||
} else {
|
||||
require.Error(t, relationErr)
|
||||
require.True(t, store.IsErrObjectNotFound(relationErr))
|
||||
require.Nil(t, relation)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateEndpointFailure(t *testing.T) {
|
||||
fips.InitFIPS(false)
|
||||
|
||||
_, store := datastore.MustNewTestStore(t, true, false)
|
||||
|
||||
h := NewHandler(testhelpers.NewTestRequestBouncer())
|
||||
h.DataStore = store
|
||||
|
||||
payload := &endpointCreatePayload{
|
||||
Name: "Test Endpoint",
|
||||
EndpointCreationType: agentEnvironment,
|
||||
TLS: true,
|
||||
TLSCertFile: []byte("invalid data"),
|
||||
TLSKeyFile: []byte("invalid data"),
|
||||
}
|
||||
|
||||
endpoint, httpErr := h.createEndpoint(store, payload)
|
||||
require.NotNil(t, httpErr)
|
||||
require.Equal(t, http.StatusInternalServerError, httpErr.StatusCode)
|
||||
require.Nil(t, endpoint)
|
||||
}
|
||||
|
||||
func TestCreateEdgeAgentEndpoint_ContainerEngineMapping(t *testing.T) {
|
||||
fips.InitFIPS(false)
|
||||
|
||||
_, store := datastore.MustNewTestStore(t, true, false)
|
||||
|
||||
// required group for save flow
|
||||
endpointGroup := &portainer.EndpointGroup{ID: 1, Name: "test-group"}
|
||||
err := store.EndpointGroup().Create(endpointGroup)
|
||||
require.NoError(t, err)
|
||||
|
||||
h := &Handler{
|
||||
DataStore: store,
|
||||
ReverseTunnelService: chisel.NewService(store, nil, nil),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
engine string
|
||||
wantType portainer.EndpointType
|
||||
}{
|
||||
{
|
||||
name: "empty engine -> EdgeAgentOnKubernetesEnvironment",
|
||||
engine: "",
|
||||
wantType: portainer.EdgeAgentOnKubernetesEnvironment,
|
||||
},
|
||||
{
|
||||
name: "docker engine -> EdgeAgentOnDockerEnvironment",
|
||||
engine: portainer.ContainerEngineDocker,
|
||||
wantType: portainer.EdgeAgentOnDockerEnvironment,
|
||||
},
|
||||
{
|
||||
name: "podman engine -> EdgeAgentOnDockerEnvironment",
|
||||
engine: portainer.ContainerEnginePodman,
|
||||
wantType: portainer.EdgeAgentOnDockerEnvironment,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
payload := &endpointCreatePayload{
|
||||
Name: "edge-endpoint",
|
||||
EndpointCreationType: edgeAgentEnvironment,
|
||||
ContainerEngine: tc.engine,
|
||||
GroupID: 1,
|
||||
URL: "https://portainer.example:9443",
|
||||
}
|
||||
|
||||
ep, httpErr := h.createEdgeAgentEndpoint(store, payload)
|
||||
require.Nil(t, httpErr)
|
||||
require.NotNil(t, ep)
|
||||
|
||||
assert.Equal(t, tc.wantType, ep.Type)
|
||||
assert.Equal(t, tc.engine, ep.ContainerEngine)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -49,7 +49,7 @@ func (handler *Handler) endpointSnapshots(w http.ResponseWriter, r *http.Request
|
||||
continue
|
||||
}
|
||||
|
||||
endpoint.Status = portainer.EndpointStatusUp
|
||||
latestEndpointReference.Status = portainer.EndpointStatusUp
|
||||
if snapshotError != nil {
|
||||
log.Debug().
|
||||
Str("endpoint", endpoint.Name).
|
||||
@@ -57,7 +57,7 @@ func (handler *Handler) endpointSnapshots(w http.ResponseWriter, r *http.Request
|
||||
Err(snapshotError).
|
||||
Msg("background schedule error (environment snapshot), unable to create snapshot")
|
||||
|
||||
endpoint.Status = portainer.EndpointStatusDown
|
||||
latestEndpointReference.Status = portainer.EndpointStatusDown
|
||||
}
|
||||
|
||||
latestEndpointReference.Agent.Version = endpoint.Agent.Version
|
||||
|
||||
107
api/http/handler/endpoints/endpoint_snapshots_test.go
Normal file
107
api/http/handler/endpoints/endpoint_snapshots_test.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package endpoints
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
"github.com/portainer/portainer/api/internal/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_endpointSnapshots(t *testing.T) {
|
||||
_, store := datastore.MustNewTestStore(t, true, true)
|
||||
|
||||
endpointID := portainer.EndpointID(123)
|
||||
endpoint := &portainer.Endpoint{
|
||||
ID: endpointID,
|
||||
Name: "mock",
|
||||
URL: "http://mock.example/",
|
||||
Status: portainer.EndpointStatusDown, // starts in down state
|
||||
}
|
||||
err := store.Endpoint().Create(endpoint)
|
||||
|
||||
require.NoError(t, err, "error creating environment")
|
||||
|
||||
err = store.User().Create(
|
||||
&portainer.User{
|
||||
Username: "admin",
|
||||
Role: portainer.AdministratorRole,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err, "error creating a user")
|
||||
|
||||
bouncer := testhelpers.NewTestRequestBouncer()
|
||||
|
||||
snapshotService := &mockSnapshotService{
|
||||
snapshotEndpointShouldSucceed: atomic.Bool{},
|
||||
}
|
||||
snapshotService.snapshotEndpointShouldSucceed.Store(true)
|
||||
|
||||
h := NewHandler(bouncer)
|
||||
h.DataStore = store
|
||||
h.SnapshotService = snapshotService
|
||||
|
||||
doPostRequest := func() {
|
||||
req := httptest.NewRequest(http.MethodPost, "/endpoints/snapshot", nil)
|
||||
ctx := security.StoreTokenData(req, &portainer.TokenData{ID: 1, Username: "admin", Role: 1})
|
||||
req = req.WithContext(ctx)
|
||||
testhelpers.AddTestSecurityCookie(req, "Bearer dummytoken")
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
h.ServeHTTP(rr, req)
|
||||
|
||||
require.Equal(t, http.StatusNoContent, rr.Code, "Status should be 204")
|
||||
|
||||
_, err := io.ReadAll(rr.Body)
|
||||
require.NoError(t, err, "ReadAll should not return error")
|
||||
}
|
||||
|
||||
doPostRequest()
|
||||
|
||||
// check that the endpoint has been immediately set to up
|
||||
endpoint, err = store.Endpoint().Endpoint(endpointID)
|
||||
require.NoError(t, err, "error getting endpoint")
|
||||
assert.Equal(t, portainer.EndpointStatusUp, endpoint.Status, "endpoint should be up (1) since mock snapshot returned ok")
|
||||
|
||||
// set the mock to return an error
|
||||
snapshotService.snapshotEndpointShouldSucceed.Store(false)
|
||||
doPostRequest()
|
||||
|
||||
// check that the endpoint has been immediately set to down
|
||||
endpoint, err = store.Endpoint().Endpoint(endpointID)
|
||||
require.NoError(t, err, "error getting endpoint")
|
||||
assert.Equal(t, portainer.EndpointStatusDown, endpoint.Status, "endpoint should be down (2) since mock snapshot returned error")
|
||||
}
|
||||
|
||||
var _ portainer.SnapshotService = &mockSnapshotService{}
|
||||
|
||||
type mockSnapshotService struct {
|
||||
snapshotEndpointShouldSucceed atomic.Bool
|
||||
}
|
||||
|
||||
func (s *mockSnapshotService) Start() {
|
||||
}
|
||||
|
||||
func (s *mockSnapshotService) SetSnapshotInterval(snapshotInterval string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *mockSnapshotService) SnapshotEndpoint(endpoint *portainer.Endpoint) error {
|
||||
if s.snapshotEndpointShouldSucceed.Load() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return errors.New("snapshot failed")
|
||||
}
|
||||
|
||||
func (s *mockSnapshotService) FillSnapshotData(endpoint *portainer.Endpoint, includeRaw bool) error {
|
||||
return nil
|
||||
}
|
||||
@@ -256,7 +256,7 @@ func (handler *Handler) filterEndpointsByQuery(
|
||||
return filteredEndpoints, totalAvailableEndpoints, nil
|
||||
}
|
||||
|
||||
func endpointStatusInStackMatchesFilter(stackStatus *portainer.EdgeStackStatusForEnv, envId portainer.EndpointID, statusFilter portainer.EdgeStackStatusType) bool {
|
||||
func endpointStatusInStackMatchesFilter(stackStatus *portainer.EdgeStackStatusForEnv, statusFilter portainer.EdgeStackStatusType) bool {
|
||||
// consider that if the env has no status in the stack it is in Pending state
|
||||
if statusFilter == portainer.EdgeStackStatusPending {
|
||||
return stackStatus == nil || len(stackStatus.Status) == 0
|
||||
@@ -272,55 +272,62 @@ func endpointStatusInStackMatchesFilter(stackStatus *portainer.EdgeStackStatusFo
|
||||
}
|
||||
|
||||
func filterEndpointsByEdgeStack(endpoints []portainer.Endpoint, edgeStackId portainer.EdgeStackID, statusFilter *portainer.EdgeStackStatusType, datastore dataservices.DataStore) ([]portainer.Endpoint, error) {
|
||||
stack, err := datastore.EdgeStack().EdgeStack(edgeStackId)
|
||||
if err != nil {
|
||||
return nil, errors.WithMessage(err, "Unable to retrieve edge stack from the database")
|
||||
}
|
||||
|
||||
envIds := roar.Roar[portainer.EndpointID]{}
|
||||
for _, edgeGroupdId := range stack.EdgeGroups {
|
||||
edgeGroup, err := datastore.EdgeGroup().Read(edgeGroupdId)
|
||||
var filteredEndpoints []portainer.Endpoint
|
||||
if err := datastore.ViewTx(func(tx dataservices.DataStoreTx) error {
|
||||
stack, err := tx.EdgeStack().EdgeStack(edgeStackId)
|
||||
if err != nil {
|
||||
return nil, errors.WithMessage(err, "Unable to retrieve edge group from the database")
|
||||
return errors.WithMessage(err, "Unable to retrieve edge stack from the database")
|
||||
}
|
||||
|
||||
if edgeGroup.Dynamic {
|
||||
endpointIDs, err := edgegroups.GetEndpointsByTags(datastore, edgeGroup.TagIDs, edgeGroup.PartialMatch)
|
||||
envIds := roar.Roar[portainer.EndpointID]{}
|
||||
for _, edgeGroupId := range stack.EdgeGroups {
|
||||
edgeGroup, err := tx.EdgeGroup().Read(edgeGroupId)
|
||||
if err != nil {
|
||||
return nil, errors.WithMessage(err, "Unable to retrieve environments and environment groups for Edge group")
|
||||
return errors.WithMessage(err, "Unable to retrieve edge group from the database")
|
||||
}
|
||||
edgeGroup.EndpointIDs = roar.FromSlice(endpointIDs)
|
||||
|
||||
if edgeGroup.Dynamic {
|
||||
endpointIDs, err := edgegroups.GetEndpointsByTags(tx, edgeGroup.TagIDs, edgeGroup.PartialMatch)
|
||||
if err != nil {
|
||||
return errors.WithMessage(err, "Unable to retrieve environments and environment groups for Edge group")
|
||||
}
|
||||
edgeGroup.EndpointIDs = roar.FromSlice(endpointIDs)
|
||||
}
|
||||
|
||||
envIds.Union(edgeGroup.EndpointIDs)
|
||||
}
|
||||
|
||||
envIds.Union(edgeGroup.EndpointIDs)
|
||||
}
|
||||
filteredEnvIds := roar.Roar[portainer.EndpointID]{}
|
||||
filteredEnvIds.Union(envIds)
|
||||
|
||||
if statusFilter != nil {
|
||||
var innerErr error
|
||||
if statusFilter != nil {
|
||||
var innerErr error
|
||||
|
||||
envIds.Iterate(func(envId portainer.EndpointID) bool {
|
||||
edgeStackStatus, err := tx.EdgeStackStatus().Read(edgeStackId, envId)
|
||||
if err != nil && !dataservices.IsErrObjectNotFound(err) {
|
||||
innerErr = errors.WithMessagef(err, "Unable to retrieve edge stack status for environment %d", envId)
|
||||
return false
|
||||
}
|
||||
|
||||
if !endpointStatusInStackMatchesFilter(edgeStackStatus, *statusFilter) {
|
||||
filteredEnvIds.Remove(envId)
|
||||
}
|
||||
|
||||
envIds.Iterate(func(envId portainer.EndpointID) bool {
|
||||
edgeStackStatus, err := datastore.EdgeStackStatus().Read(edgeStackId, envId)
|
||||
if dataservices.IsErrObjectNotFound(err) {
|
||||
return true
|
||||
} else if err != nil {
|
||||
innerErr = errors.WithMessagef(err, "Unable to retrieve edge stack status for environment %d", envId)
|
||||
return false
|
||||
})
|
||||
|
||||
if innerErr != nil {
|
||||
return innerErr
|
||||
}
|
||||
|
||||
if !endpointStatusInStackMatchesFilter(edgeStackStatus, portainer.EndpointID(envId), *statusFilter) {
|
||||
envIds.Remove(envId)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
if innerErr != nil {
|
||||
return nil, innerErr
|
||||
}
|
||||
|
||||
filteredEndpoints = filteredEndpointsByIds(endpoints, filteredEnvIds)
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
filteredEndpoints := filteredEndpointsByIds(endpoints, envIds)
|
||||
|
||||
return filteredEndpoints, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
"github.com/portainer/portainer/api/internal/testhelpers"
|
||||
@@ -304,42 +305,103 @@ func TestFilterEndpointsByEdgeStack(t *testing.T) {
|
||||
_, store := datastore.MustNewTestStore(t, false, false)
|
||||
|
||||
endpoints := []portainer.Endpoint{
|
||||
{ID: 1, Name: "Endpoint 1"},
|
||||
{ID: 2, Name: "Endpoint 2"},
|
||||
{ID: 3, Name: "Endpoint 3"},
|
||||
{ID: 1, Name: "Endpoint 1", Type: portainer.EdgeAgentOnDockerEnvironment, UserTrusted: true},
|
||||
{ID: 2, Name: "Endpoint 2", TagIDs: []portainer.TagID{1}, Type: portainer.EdgeAgentOnDockerEnvironment, UserTrusted: true},
|
||||
{ID: 3, Name: "Endpoint 3", TagIDs: []portainer.TagID{1}, Type: portainer.EdgeAgentOnDockerEnvironment, UserTrusted: true},
|
||||
{ID: 4, Name: "Endpoint 4"},
|
||||
}
|
||||
|
||||
edgeStackId := portainer.EdgeStackID(1)
|
||||
require.NoError(t, store.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||
require.NoError(t, tx.Tag().Create(&portainer.Tag{ID: 1, Name: "tag", Endpoints: map[portainer.EndpointID]bool{2: true, 3: true}}))
|
||||
|
||||
err := store.EdgeStack().Create(edgeStackId, &portainer.EdgeStack{
|
||||
ID: edgeStackId,
|
||||
Name: "Test Edge Stack",
|
||||
EdgeGroups: []portainer.EdgeGroupID{1, 2},
|
||||
for i := range endpoints {
|
||||
require.NoError(t, tx.Endpoint().Create(&endpoints[i]))
|
||||
}
|
||||
|
||||
require.NoError(t, tx.EdgeStack().Create(edgeStackId, &portainer.EdgeStack{
|
||||
ID: edgeStackId,
|
||||
Name: "Test Edge Stack",
|
||||
EdgeGroups: []portainer.EdgeGroupID{1, 2},
|
||||
}))
|
||||
|
||||
require.NoError(t, tx.EdgeGroup().Create(&portainer.EdgeGroup{
|
||||
ID: 1,
|
||||
Name: "Edge Group 1",
|
||||
EndpointIDs: roar.FromSlice([]portainer.EndpointID{1}),
|
||||
}))
|
||||
|
||||
require.NoError(t, tx.EdgeGroup().Create(&portainer.EdgeGroup{
|
||||
ID: 2,
|
||||
Name: "Edge Group 2",
|
||||
Dynamic: true,
|
||||
TagIDs: []portainer.TagID{1},
|
||||
}))
|
||||
|
||||
require.NoError(t, tx.EdgeStackStatus().Create(edgeStackId, endpoints[0].ID, &portainer.EdgeStackStatusForEnv{
|
||||
Status: []portainer.EdgeStackDeploymentStatus{{Type: portainer.EdgeStackStatusAcknowledged}}}))
|
||||
|
||||
return nil
|
||||
}))
|
||||
|
||||
test := func(status *portainer.EdgeStackStatusType, expected []portainer.Endpoint) {
|
||||
tmp := make([]portainer.Endpoint, len(endpoints))
|
||||
require.Equal(t, 4, copy(tmp, endpoints))
|
||||
es, err := filterEndpointsByEdgeStack(tmp, edgeStackId, status, store)
|
||||
require.NoError(t, err)
|
||||
// validate that the len is the same
|
||||
require.Len(t, es, len(expected))
|
||||
// and that all items are the expected ones
|
||||
for i := range expected {
|
||||
require.Contains(t, es, expected[i])
|
||||
}
|
||||
}
|
||||
|
||||
test(nil, []portainer.Endpoint{endpoints[0], endpoints[1], endpoints[2]})
|
||||
|
||||
status := portainer.EdgeStackStatusPending
|
||||
test(&status, []portainer.Endpoint{endpoints[1], endpoints[2]})
|
||||
|
||||
status = portainer.EdgeStackStatusCompleted
|
||||
test(&status, []portainer.Endpoint{})
|
||||
|
||||
status = portainer.EdgeStackStatusAcknowledged
|
||||
test(&status, []portainer.Endpoint{endpoints[0]}) // that's the only one with an edge stack status in DB
|
||||
}
|
||||
|
||||
func TestErrorsFilterEndpointsByEdgeStack(t *testing.T) {
|
||||
t.Run("must error by edge stack not found", func(t *testing.T) {
|
||||
_, store := datastore.MustNewTestStore(t, false, false)
|
||||
require.NotNil(t, store)
|
||||
|
||||
_, err := filterEndpointsByEdgeStack([]portainer.Endpoint{}, 1, nil, store)
|
||||
require.Error(t, err)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.EdgeGroup().Create(&portainer.EdgeGroup{
|
||||
ID: 1,
|
||||
Name: "Edge Group 1",
|
||||
EndpointIDs: roar.FromSlice([]portainer.EndpointID{1}),
|
||||
t.Run("must error by edge group not found", func(t *testing.T) {
|
||||
_, store := datastore.MustNewTestStore(t, false, false)
|
||||
require.NotNil(t, store)
|
||||
|
||||
require.NoError(t, store.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||
require.NoError(t, tx.EdgeStack().Create(1, &portainer.EdgeStack{ID: 1, Name: "1", EdgeGroups: []portainer.EdgeGroupID{1}}))
|
||||
return nil
|
||||
}))
|
||||
_, err := filterEndpointsByEdgeStack([]portainer.Endpoint{}, 1, nil, store)
|
||||
require.Error(t, err)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.EdgeGroup().Create(&portainer.EdgeGroup{
|
||||
ID: 2,
|
||||
Name: "Edge Group 2",
|
||||
EndpointIDs: roar.FromSlice([]portainer.EndpointID{2, 3}),
|
||||
t.Run("must error by env tag not found", func(t *testing.T) {
|
||||
_, store := datastore.MustNewTestStore(t, false, false)
|
||||
require.NotNil(t, store)
|
||||
|
||||
require.NoError(t, store.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||
require.NoError(t, tx.EdgeStack().Create(1, &portainer.EdgeStack{ID: 1, Name: "1", EdgeGroups: []portainer.EdgeGroupID{1}}))
|
||||
require.NoError(t, tx.EdgeGroup().Create(&portainer.EdgeGroup{ID: 1, Name: "edge group", Dynamic: true, TagIDs: []portainer.TagID{1}}))
|
||||
return nil
|
||||
}))
|
||||
_, err := filterEndpointsByEdgeStack([]portainer.Endpoint{}, 1, nil, store)
|
||||
require.Error(t, err)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
es, err := filterEndpointsByEdgeStack(endpoints, edgeStackId, nil, store)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, es, 3)
|
||||
require.Contains(t, es, endpoints[0]) // Endpoint 1
|
||||
require.Contains(t, es, endpoints[1]) // Endpoint 2
|
||||
require.Contains(t, es, endpoints[2]) // Endpoint 3
|
||||
require.NotContains(t, es, endpoints[3]) // Endpoint 4
|
||||
}
|
||||
|
||||
func TestFilterEndpointsByEdgeGroup(t *testing.T) {
|
||||
|
||||
@@ -55,6 +55,10 @@ func (handler *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
if r.RequestURI == "/" || strings.HasSuffix(r.RequestURI, ".html") {
|
||||
w.Header().Set("Permissions-Policy", strings.Join(permissions, ","))
|
||||
}
|
||||
|
||||
if !isHTML(r.Header["Accept"]) {
|
||||
w.Header().Set("Cache-Control", "max-age=31536000")
|
||||
} else {
|
||||
|
||||
70
api/http/handler/file/handler_test.go
Normal file
70
api/http/handler/file/handler_test.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package file_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/portainer/portainer/api/http/handler/file"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNormalServe(t *testing.T) {
|
||||
handler := file.NewHandler("", false, func() bool { return false })
|
||||
require.NotNil(t, handler)
|
||||
|
||||
request := func(path string) (*http.Request, *httptest.ResponseRecorder) {
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, path, nil)
|
||||
handler.ServeHTTP(rr, req)
|
||||
return req, rr
|
||||
}
|
||||
|
||||
_, rr := request("/timeout.html")
|
||||
require.Equal(t, http.StatusTemporaryRedirect, rr.Result().StatusCode)
|
||||
loc, err := rr.Result().Location()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, loc)
|
||||
require.Equal(t, "/", loc.Path)
|
||||
|
||||
_, rr = request("/")
|
||||
require.Equal(t, http.StatusOK, rr.Result().StatusCode)
|
||||
}
|
||||
|
||||
func TestPermissionsPolicyHeader(t *testing.T) {
|
||||
handler := file.NewHandler("", false, func() bool { return false })
|
||||
require.NotNil(t, handler)
|
||||
|
||||
test := func(path string, exist bool) {
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, path, nil)
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
require.Equal(t, exist, rr.Result().Header.Get("Permissions-Policy") != "")
|
||||
}
|
||||
|
||||
test("/", true)
|
||||
test("/index.html", true)
|
||||
test("/api", false)
|
||||
test("/an/image.png", false)
|
||||
}
|
||||
|
||||
func TestRedirectInstanceDisabled(t *testing.T) {
|
||||
handler := file.NewHandler("", false, func() bool { return true })
|
||||
require.NotNil(t, handler)
|
||||
|
||||
test := func(path string) {
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, path, nil)
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
require.Equal(t, http.StatusTemporaryRedirect, rr.Result().StatusCode)
|
||||
loc, err := rr.Result().Location()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, loc)
|
||||
require.Equal(t, "/timeout.html", loc.Path)
|
||||
}
|
||||
|
||||
test("/")
|
||||
test("/index.html")
|
||||
}
|
||||
91
api/http/handler/file/permissions_list.go
Normal file
91
api/http/handler/file/permissions_list.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package file
|
||||
|
||||
var permissions = []string{
|
||||
"accelerometer=()",
|
||||
"ambient-light-sensor=()",
|
||||
"attribution-reporting=()",
|
||||
"autoplay=()",
|
||||
"battery=()",
|
||||
"browsing-topics=()",
|
||||
"camera=()",
|
||||
"captured-surface-control=()",
|
||||
"ch-device-memory=()",
|
||||
"ch-downlink=()",
|
||||
"ch-dpr=()",
|
||||
"ch-ect=()",
|
||||
"ch-prefers-color-scheme=()",
|
||||
"ch-prefers-reduced-motion=()",
|
||||
"ch-prefers-reduced-transparency=()",
|
||||
"ch-rtt=()",
|
||||
"ch-save-data=()",
|
||||
"ch-ua=()",
|
||||
"ch-ua-arch=()",
|
||||
"ch-ua-bitness=()",
|
||||
"ch-ua-form-factors=()",
|
||||
"ch-ua-full-version=()",
|
||||
"ch-ua-full-version-list=()",
|
||||
"ch-ua-mobile=()",
|
||||
"ch-ua-model=()",
|
||||
"ch-ua-platform=()",
|
||||
"ch-ua-platform-version=()",
|
||||
"ch-ua-wow64=()",
|
||||
"ch-viewport-height=()",
|
||||
"ch-viewport-width=()",
|
||||
"ch-width=()",
|
||||
"compute-pressure=()",
|
||||
"conversion-measurement=()",
|
||||
"cross-origin-isolated=()",
|
||||
"deferred-fetch=()",
|
||||
"deferred-fetch-minimal=()",
|
||||
"display-capture=()",
|
||||
"document-domain=()",
|
||||
"encrypted-media=()",
|
||||
"execution-while-not-rendered=()",
|
||||
"execution-while-out-of-viewport=()",
|
||||
"focus-without-user-activation=()",
|
||||
"fullscreen=()",
|
||||
"gamepad=()",
|
||||
"geolocation=()",
|
||||
"gyroscope=()",
|
||||
"hid=()",
|
||||
"identity-credentials-get=()",
|
||||
"idle-detection=()",
|
||||
"interest-cohort=()",
|
||||
"join-ad-interest-group=()",
|
||||
"keyboard-map=()",
|
||||
"language-detector=()",
|
||||
"local-fonts=()",
|
||||
"magnetometer=()",
|
||||
"microphone=()",
|
||||
"midi=()",
|
||||
"navigation-override=()",
|
||||
"otp-credentials=()",
|
||||
"payment=()",
|
||||
"picture-in-picture=()",
|
||||
"private-aggregation=()",
|
||||
"private-state-token-issuance=()",
|
||||
"private-state-token-redemption=()",
|
||||
"publickey-credentials-create=()",
|
||||
"publickey-credentials-get=()",
|
||||
"rewriter=()",
|
||||
"run-ad-auction=()",
|
||||
"screen-wake-lock=()",
|
||||
"serial=()",
|
||||
"shared-storage=()",
|
||||
"shared-storage-select-url=()",
|
||||
"speaker-selection=()",
|
||||
"storage-access=()",
|
||||
"summarizer=()",
|
||||
"sync-script=()",
|
||||
"sync-xhr=()",
|
||||
"translator=()",
|
||||
"trust-token-redemption=()",
|
||||
"unload=()",
|
||||
"usb=()",
|
||||
"vertical-scroll=()",
|
||||
"web-share=()",
|
||||
"window-management=()",
|
||||
"window-placement=()",
|
||||
"writer=()",
|
||||
"xr-spatial-tracking=()",
|
||||
}
|
||||
@@ -81,7 +81,7 @@ type Handler struct {
|
||||
}
|
||||
|
||||
// @title PortainerCE API
|
||||
// @version 2.33.0-rc1
|
||||
// @version 2.33.4
|
||||
// @description.markdown api-description.md
|
||||
// @termsOfService
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/portainer/portainer/api/internal/testhelpers"
|
||||
"github.com/portainer/portainer/pkg/libhelm/test"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,10 +1,19 @@
|
||||
package registries
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
"github.com/portainer/portainer/api/internal/testhelpers"
|
||||
|
||||
"github.com/segmentio/encoding/json"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_registryCreatePayload_Validate(t *testing.T) {
|
||||
@@ -43,3 +52,46 @@ func Test_registryCreatePayload_Validate(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandler_registryCreate(t *testing.T) {
|
||||
_, store := datastore.MustNewTestStore(t, false, false)
|
||||
|
||||
payload := registryCreatePayload{
|
||||
Name: "Test registry",
|
||||
Type: portainer.ProGetRegistry,
|
||||
URL: "http://example.com",
|
||||
BaseURL: "http://example.com",
|
||||
Authentication: false,
|
||||
Username: "username",
|
||||
Password: "password",
|
||||
Gitlab: portainer.GitlabRegistryData{},
|
||||
}
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
r := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(payloadBytes))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
restrictedContext := &security.RestrictedRequestContext{IsAdmin: true, UserID: 1}
|
||||
|
||||
ctx := security.StoreRestrictedRequestContext(r, restrictedContext)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
handler := NewHandler(testhelpers.NewTestRequestBouncer())
|
||||
handler.DataStore = store
|
||||
|
||||
handlerError := handler.registryCreate(w, r)
|
||||
require.Nil(t, handlerError)
|
||||
|
||||
registry := portainer.Registry{}
|
||||
err = json.NewDecoder(w.Body).Decode(®istry)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, payload.Name, registry.Name)
|
||||
assert.Equal(t, payload.Type, registry.Type)
|
||||
assert.Equal(t, payload.URL, registry.URL)
|
||||
assert.Equal(t, payload.BaseURL, registry.BaseURL)
|
||||
assert.Equal(t, payload.Authentication, registry.Authentication)
|
||||
assert.Equal(t, payload.Username, registry.Username)
|
||||
assert.Empty(t, registry.Password)
|
||||
}
|
||||
|
||||
@@ -177,6 +177,8 @@ func (handler *Handler) registryUpdate(w http.ResponseWriter, r *http.Request) *
|
||||
return httperror.InternalServerError("Unable to persist registry changes inside the database", err)
|
||||
}
|
||||
|
||||
hideFields(registry, true)
|
||||
|
||||
return response.JSON(w, registry)
|
||||
}
|
||||
|
||||
|
||||
68
api/http/handler/registries/registry_update_test.go
Normal file
68
api/http/handler/registries/registry_update_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package registries
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
"github.com/portainer/portainer/api/internal/testhelpers"
|
||||
|
||||
"github.com/segmentio/encoding/json"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func ptr[T any](i T) *T { return &i }
|
||||
|
||||
func TestHandler_registryUpdate(t *testing.T) {
|
||||
_, store := datastore.MustNewTestStore(t, false, false)
|
||||
|
||||
registry := &portainer.Registry{Type: portainer.ProGetRegistry}
|
||||
|
||||
err := store.Registry().Create(registry)
|
||||
require.NoError(t, err)
|
||||
|
||||
payload := registryUpdatePayload{
|
||||
Name: ptr("Updated test registry"),
|
||||
URL: ptr("http://example.org/feed"),
|
||||
BaseURL: ptr("http://example.org"),
|
||||
Authentication: ptr(true),
|
||||
Username: ptr("username"),
|
||||
Password: ptr("password"),
|
||||
}
|
||||
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
r := httptest.NewRequest(http.MethodPut, "/registries/1", bytes.NewReader(payloadBytes))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
restrictedContext := &security.RestrictedRequestContext{IsAdmin: true, UserID: 1}
|
||||
|
||||
ctx := security.StoreRestrictedRequestContext(r, restrictedContext)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
handler := NewHandler(testhelpers.NewTestRequestBouncer())
|
||||
handler.DataStore = store
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
updatedRegistry := portainer.Registry{}
|
||||
err = json.NewDecoder(w.Body).Decode(&updatedRegistry)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Registry type should remain intact
|
||||
assert.Equal(t, registry.Type, updatedRegistry.Type)
|
||||
|
||||
assert.Equal(t, *payload.Name, updatedRegistry.Name)
|
||||
assert.Equal(t, *payload.URL, updatedRegistry.URL)
|
||||
assert.Equal(t, *payload.BaseURL, updatedRegistry.BaseURL)
|
||||
assert.Equal(t, *payload.Authentication, updatedRegistry.Authentication)
|
||||
assert.Equal(t, *payload.Username, updatedRegistry.Username)
|
||||
assert.Empty(t, updatedRegistry.Password)
|
||||
}
|
||||
@@ -73,6 +73,14 @@ func (handler *Handler) stackUpdateGit(w http.ResponseWriter, r *http.Request) *
|
||||
return httperror.InternalServerError(msg, errors.New(msg))
|
||||
}
|
||||
|
||||
if payload.AutoUpdate != nil && payload.AutoUpdate.Webhook != "" &&
|
||||
(stack.AutoUpdate == nil ||
|
||||
(stack.AutoUpdate != nil && stack.AutoUpdate.Webhook != payload.AutoUpdate.Webhook)) {
|
||||
if isUnique, err := handler.checkUniqueWebhookID(payload.AutoUpdate.Webhook); !isUnique || err != nil {
|
||||
return httperror.Conflict("Webhook ID already exists", errors.New("webhook ID already exists"))
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: this is a work-around for stacks created with Portainer version >= 1.17.1
|
||||
// The EndpointID property is not available for these stacks, this API environment(endpoint)
|
||||
// can use the optional EndpointID query parameter to associate a valid environment(endpoint) identifier to the stack.
|
||||
|
||||
78
api/http/handler/stacks/stack_update_git_test.go
Normal file
78
api/http/handler/stacks/stack_update_git_test.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package stacks
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
gittypes "github.com/portainer/portainer/api/git/types"
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
"github.com/portainer/portainer/api/internal/testhelpers"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/segmentio/encoding/json"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStackUpdateGitWebhookUniqueness(t *testing.T) {
|
||||
webhook, err := uuid.NewV4()
|
||||
require.NoError(t, err)
|
||||
|
||||
_, store := datastore.MustNewTestStore(t, false, false)
|
||||
|
||||
endpoint := &portainer.Endpoint{
|
||||
ID: 123,
|
||||
Name: "endpoint1",
|
||||
Type: portainer.DockerEnvironment,
|
||||
}
|
||||
err = store.Endpoint().Create(endpoint)
|
||||
require.NoError(t, err)
|
||||
|
||||
stack1 := portainer.Stack{
|
||||
ID: 456,
|
||||
EndpointID: endpoint.ID,
|
||||
AutoUpdate: &portainer.AutoUpdateSettings{
|
||||
Webhook: webhook.String(),
|
||||
},
|
||||
GitConfig: &gittypes.RepoConfig{
|
||||
URL: "https://github.com/portainer/portainer.git",
|
||||
},
|
||||
}
|
||||
|
||||
err = store.Stack().Create(&stack1)
|
||||
require.NoError(t, err)
|
||||
|
||||
stack2 := stack1
|
||||
stack2.ID++
|
||||
stack2.AutoUpdate = nil
|
||||
|
||||
err = store.Stack().Create(&stack2)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := NewHandler(testhelpers.NewTestRequestBouncer())
|
||||
handler.DataStore = store
|
||||
|
||||
payload := &stackGitUpdatePayload{
|
||||
AutoUpdate: &portainer.AutoUpdateSettings{
|
||||
Webhook: webhook.String(),
|
||||
},
|
||||
}
|
||||
|
||||
jsonPayload, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
url := "/stacks/" + strconv.Itoa(int(stack2.ID)) + "/git?endpointId=" + strconv.Itoa(int(endpoint.ID))
|
||||
req := httptest.NewRequest(http.MethodPost, url, bytes.NewReader(jsonPayload))
|
||||
|
||||
rrc := &security.RestrictedRequestContext{}
|
||||
req = req.WithContext(security.StoreRestrictedRequestContext(req, rrc))
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
require.Equal(t, http.StatusConflict, rr.Code)
|
||||
}
|
||||
@@ -7,11 +7,14 @@ import (
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestInitDial(t *testing.T) {
|
||||
fips.InitFIPS(false)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
defer srv.Close()
|
||||
|
||||
|
||||
@@ -36,13 +36,15 @@ type K8sApplication struct {
|
||||
Kind string `json:"Kind,omitempty"`
|
||||
MatchLabels map[string]string `json:"MatchLabels,omitempty"`
|
||||
Labels map[string]string `json:"Labels,omitempty"`
|
||||
Annotations map[string]string `json:"Annotations,omitempty"`
|
||||
Resource K8sApplicationResource `json:"Resource,omitempty"`
|
||||
HorizontalPodAutoscaler *autoscalingv2.HorizontalPodAutoscaler `json:"HorizontalPodAutoscaler,omitempty"`
|
||||
CustomResourceMetadata CustomResourceMetadata `json:"CustomResourceMetadata,omitempty"`
|
||||
}
|
||||
|
||||
type Metadata struct {
|
||||
Labels map[string]string `json:"labels"`
|
||||
Labels map[string]string `json:"labels"`
|
||||
Annotations map[string]string `json:"annotations"`
|
||||
}
|
||||
|
||||
type CustomResourceMetadata struct {
|
||||
|
||||
@@ -5,11 +5,13 @@ import (
|
||||
"strings"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/docker/consts"
|
||||
"github.com/portainer/portainer/api/http/proxy/factory/utils"
|
||||
"github.com/portainer/portainer/api/internal/authorization"
|
||||
"github.com/portainer/portainer/api/slicesx"
|
||||
"github.com/portainer/portainer/api/stacks/stackutils"
|
||||
|
||||
"github.com/docker/docker/client"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
@@ -17,9 +19,6 @@ const (
|
||||
resourceLabelForPortainerTeamResourceControl = "io.portainer.accesscontrol.teams"
|
||||
resourceLabelForPortainerUserResourceControl = "io.portainer.accesscontrol.users"
|
||||
resourceLabelForPortainerPublicResourceControl = "io.portainer.accesscontrol.public"
|
||||
resourceLabelForDockerSwarmStackName = "com.docker.stack.namespace"
|
||||
resourceLabelForDockerServiceID = "com.docker.swarm.service.id"
|
||||
resourceLabelForDockerComposeStackName = "com.docker.compose.project"
|
||||
)
|
||||
|
||||
type (
|
||||
@@ -123,13 +122,7 @@ func (transport *Transport) createPrivateResourceControl(resourceIdentifier stri
|
||||
return resourceControl, nil
|
||||
}
|
||||
|
||||
func (transport *Transport) getInheritedResourceControlFromServiceOrStack(resourceIdentifier, nodeName string, resourceType portainer.ResourceControlType, resourceControls []portainer.ResourceControl) (*portainer.ResourceControl, error) {
|
||||
client, err := transport.dockerClientFactory.CreateClient(transport.endpoint, nodeName, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
func (transport *Transport) getInheritedResourceControlFromServiceOrStack(client *client.Client, resourceIdentifier string, resourceType portainer.ResourceControlType, resourceControls []portainer.ResourceControl) (*portainer.ResourceControl, error) {
|
||||
switch resourceType {
|
||||
case portainer.ContainerResourceControl:
|
||||
return getInheritedResourceControlFromContainerLabels(client, transport.endpoint.ID, resourceIdentifier, resourceControls)
|
||||
@@ -295,8 +288,8 @@ func (transport *Transport) findResourceControl(resourceIdentifier string, resou
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if resourceLabelsObject[resourceLabelForDockerServiceID] != nil {
|
||||
inheritedServiceIdentifier := resourceLabelsObject[resourceLabelForDockerServiceID].(string)
|
||||
if resourceLabelsObject[consts.SwarmServiceIDLabel] != nil {
|
||||
inheritedServiceIdentifier := resourceLabelsObject[consts.SwarmServiceIDLabel].(string)
|
||||
resourceControl = authorization.GetResourceControlByResourceIDAndType(inheritedServiceIdentifier, portainer.ServiceResourceControl, resourceControls)
|
||||
|
||||
if resourceControl != nil {
|
||||
@@ -304,8 +297,8 @@ func (transport *Transport) findResourceControl(resourceIdentifier string, resou
|
||||
}
|
||||
}
|
||||
|
||||
if resourceLabelsObject[resourceLabelForDockerSwarmStackName] != nil {
|
||||
stackName := resourceLabelsObject[resourceLabelForDockerSwarmStackName].(string)
|
||||
if resourceLabelsObject[consts.SwarmStackNameLabel] != nil {
|
||||
stackName := resourceLabelsObject[consts.SwarmStackNameLabel].(string)
|
||||
stackResourceID := stackutils.ResourceControlID(transport.endpoint.ID, stackName)
|
||||
resourceControl = authorization.GetResourceControlByResourceIDAndType(stackResourceID, portainer.StackResourceControl, resourceControls)
|
||||
|
||||
@@ -314,8 +307,8 @@ func (transport *Transport) findResourceControl(resourceIdentifier string, resou
|
||||
}
|
||||
}
|
||||
|
||||
if resourceLabelsObject[resourceLabelForDockerComposeStackName] != nil {
|
||||
stackName := resourceLabelsObject[resourceLabelForDockerComposeStackName].(string)
|
||||
if resourceLabelsObject[consts.ComposeStackNameLabel] != nil {
|
||||
stackName := resourceLabelsObject[consts.ComposeStackNameLabel].(string)
|
||||
stackResourceID := stackutils.ResourceControlID(transport.endpoint.ID, stackName)
|
||||
resourceControl = authorization.GetResourceControlByResourceIDAndType(stackResourceID, portainer.StackResourceControl, resourceControls)
|
||||
|
||||
@@ -328,14 +321,14 @@ func (transport *Transport) findResourceControl(resourceIdentifier string, resou
|
||||
}
|
||||
|
||||
func getStackResourceIDFromLabels(resourceLabelsObject map[string]string, endpointID portainer.EndpointID) string {
|
||||
if resourceLabelsObject[resourceLabelForDockerSwarmStackName] != "" {
|
||||
stackName := resourceLabelsObject[resourceLabelForDockerSwarmStackName]
|
||||
if resourceLabelsObject[consts.SwarmStackNameLabel] != "" {
|
||||
stackName := resourceLabelsObject[consts.SwarmStackNameLabel]
|
||||
|
||||
return stackutils.ResourceControlID(endpointID, stackName)
|
||||
}
|
||||
|
||||
if resourceLabelsObject[resourceLabelForDockerComposeStackName] != "" {
|
||||
stackName := resourceLabelsObject[resourceLabelForDockerComposeStackName]
|
||||
if resourceLabelsObject[consts.ComposeStackNameLabel] != "" {
|
||||
stackName := resourceLabelsObject[consts.ComposeStackNameLabel]
|
||||
|
||||
return stackutils.ResourceControlID(endpointID, stackName)
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/docker/consts"
|
||||
"github.com/portainer/portainer/api/http/proxy/factory/utils"
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
"github.com/portainer/portainer/api/internal/authorization"
|
||||
@@ -34,7 +35,7 @@ func getInheritedResourceControlFromContainerLabels(dockerClient *client.Client,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
serviceName := container.Config.Labels[resourceLabelForDockerServiceID]
|
||||
serviceName := container.Config.Labels[consts.SwarmServiceIDLabel]
|
||||
if serviceName != "" {
|
||||
serviceResourceControl := authorization.GetResourceControlByResourceIDAndType(serviceName, portainer.ServiceResourceControl, resourceControls)
|
||||
if serviceResourceControl != nil {
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"github.com/portainer/portainer/api/http/proxy/factory/utils"
|
||||
"github.com/portainer/portainer/api/internal/authorization"
|
||||
|
||||
"github.com/docker/docker/api/types"
|
||||
"github.com/docker/docker/api/types/swarm"
|
||||
"github.com/docker/docker/client"
|
||||
"github.com/segmentio/encoding/json"
|
||||
)
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
const serviceObjectIdentifier = "ID"
|
||||
|
||||
func getInheritedResourceControlFromServiceLabels(dockerClient *client.Client, endpointID portainer.EndpointID, serviceID string, resourceControls []portainer.ResourceControl) (*portainer.ResourceControl, error) {
|
||||
service, _, err := dockerClient.ServiceInspectWithRaw(context.Background(), serviceID, types.ServiceInspectOptions{})
|
||||
service, _, err := dockerClient.ServiceInspectWithRaw(context.Background(), serviceID, swarm.ServiceInspectOptions{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package docker
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -15,12 +16,15 @@ import (
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/docker/client"
|
||||
gittypes "github.com/portainer/portainer/api/git/types"
|
||||
"github.com/portainer/portainer/api/http/proxy/factory/utils"
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
"github.com/portainer/portainer/api/internal/authorization"
|
||||
|
||||
dockerclient "github.com/portainer/portainer/api/docker/client"
|
||||
"github.com/docker/docker/api/types/network"
|
||||
"github.com/docker/docker/api/types/swarm"
|
||||
dockerclient "github.com/docker/docker/client"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/segmentio/encoding/json"
|
||||
)
|
||||
@@ -36,7 +40,7 @@ type (
|
||||
dataStore dataservices.DataStore
|
||||
signatureService portainer.DigitalSignatureService
|
||||
reverseTunnelService portainer.ReverseTunnelService
|
||||
dockerClientFactory *dockerclient.ClientFactory
|
||||
dockerClientFactory *client.ClientFactory
|
||||
gitService portainer.GitService
|
||||
snapshotService portainer.SnapshotService
|
||||
dockerID string
|
||||
@@ -49,7 +53,7 @@ type (
|
||||
DataStore dataservices.DataStore
|
||||
SignatureService portainer.DigitalSignatureService
|
||||
ReverseTunnelService portainer.ReverseTunnelService
|
||||
DockerClientFactory *dockerclient.ClientFactory
|
||||
DockerClientFactory *client.ClientFactory
|
||||
}
|
||||
|
||||
restrictedDockerOperationContext struct {
|
||||
@@ -107,6 +111,9 @@ var prefixProxyFuncMap = map[string]func(*Transport, *http.Request, string) (*ht
|
||||
// ProxyDockerRequest intercepts a Docker API request and apply logic based
|
||||
// on the requested operation.
|
||||
func (transport *Transport) ProxyDockerRequest(request *http.Request) (*http.Response, error) {
|
||||
// from : /v1.41/containers/{id}/json
|
||||
// or : /containers/{id}/json
|
||||
// to : /containers/{id}/json
|
||||
unversionedPath := apiVersionRe.ReplaceAllString(request.URL.Path, "")
|
||||
|
||||
if transport.endpoint.Type == portainer.AgentOnDockerEnvironment || transport.endpoint.Type == portainer.EdgeAgentOnDockerEnvironment {
|
||||
@@ -119,6 +126,10 @@ func (transport *Transport) ProxyDockerRequest(request *http.Request) (*http.Res
|
||||
request.Header.Set(portainer.PortainerAgentSignatureHeader, signature)
|
||||
}
|
||||
|
||||
// from : /containers/{id}/json
|
||||
// trim to : containers/{id}/json
|
||||
// pick : [ containers, {id}, json ][0]
|
||||
// prefix : containers
|
||||
prefix := strings.Split(strings.TrimPrefix(unversionedPath, "/"), "/")[0]
|
||||
|
||||
if proxyFunc := prefixProxyFuncMap[prefix]; proxyFunc != nil {
|
||||
@@ -215,9 +226,10 @@ func (transport *Transport) proxyConfigRequest(request *http.Request, unversione
|
||||
// Assume /configs/{id}
|
||||
configID := path.Base(requestPath)
|
||||
|
||||
if request.Method == http.MethodGet {
|
||||
switch request.Method {
|
||||
case http.MethodGet:
|
||||
return transport.rewriteOperation(request, transport.configInspectOperation)
|
||||
} else if request.Method == http.MethodDelete {
|
||||
case http.MethodDelete:
|
||||
return transport.executeGenericResourceDeletionOperation(request, configID, configID, portainer.ConfigResourceControl)
|
||||
}
|
||||
|
||||
@@ -248,7 +260,6 @@ func (transport *Transport) proxyContainerRequest(request *http.Request, unversi
|
||||
if action == "json" {
|
||||
return transport.rewriteOperation(request, transport.containerInspectOperation)
|
||||
}
|
||||
|
||||
return transport.restrictedResourceOperation(request, containerID, containerID, portainer.ContainerResourceControl, false)
|
||||
} else if match, _ := path.Match("/containers/*", requestPath); match {
|
||||
// Handle /containers/{id} requests
|
||||
@@ -280,7 +291,10 @@ func (transport *Transport) proxyServiceRequest(request *http.Request, unversion
|
||||
if match, _ := path.Match("/services/*/*", requestPath); match {
|
||||
// Handle /services/{id}/{action} requests
|
||||
serviceID := path.Base(path.Dir(requestPath))
|
||||
transport.decorateRegistryAuthenticationHeader(request)
|
||||
|
||||
if err := transport.decorateRegistryAuthenticationHeader(request); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return transport.restrictedResourceOperation(request, serviceID, serviceID, portainer.ServiceResourceControl, false)
|
||||
} else if match, _ := path.Match("/services/*", requestPath); match {
|
||||
@@ -320,28 +334,38 @@ func (transport *Transport) proxyVolumeRequest(request *http.Request, unversione
|
||||
}
|
||||
}
|
||||
|
||||
func match(requestPath string, pattern string) bool {
|
||||
ok, err := path.Match(pattern, requestPath)
|
||||
return err == nil && ok
|
||||
}
|
||||
|
||||
func (transport *Transport) proxyNetworkRequest(request *http.Request, unversionedPath string) (*http.Response, error) {
|
||||
requestPath := unversionedPath
|
||||
|
||||
switch requestPath {
|
||||
case "/networks/create":
|
||||
switch {
|
||||
case requestPath == "/networks/create":
|
||||
return transport.decorateGenericResourceCreationOperation(request, networkObjectIdentifier, portainer.NetworkResourceControl)
|
||||
|
||||
case "/networks":
|
||||
case requestPath == "/networks":
|
||||
return transport.rewriteOperation(request, transport.networkListOperation)
|
||||
|
||||
default:
|
||||
// Assume /networks/{id}
|
||||
networkID := path.Base(requestPath)
|
||||
|
||||
if request.Method == http.MethodGet {
|
||||
return transport.rewriteOperation(request, transport.networkInspectOperation)
|
||||
} else if request.Method == http.MethodDelete {
|
||||
return transport.executeGenericResourceDeletionOperation(request, networkID, networkID, portainer.NetworkResourceControl)
|
||||
}
|
||||
case request.Method == http.MethodPost && match(requestPath, "/networks/*/connect"),
|
||||
request.Method == http.MethodPost && match(requestPath, "/networks/*/disconnect"):
|
||||
|
||||
networkID := path.Base(path.Dir(requestPath))
|
||||
return transport.restrictedResourceOperation(request, networkID, networkID, portainer.NetworkResourceControl, false)
|
||||
|
||||
case request.Method == http.MethodGet && match(requestPath, "/networks/*"):
|
||||
return transport.rewriteOperation(request, transport.networkInspectOperation)
|
||||
|
||||
case request.Method == http.MethodDelete && match(requestPath, "/networks/*"):
|
||||
networkID := path.Base(requestPath)
|
||||
return transport.executeGenericResourceDeletionOperation(request, networkID, networkID, portainer.NetworkResourceControl)
|
||||
}
|
||||
|
||||
// Assume /networks/{id}
|
||||
networkID := path.Base(requestPath)
|
||||
return transport.restrictedResourceOperation(request, networkID, networkID, portainer.NetworkResourceControl, false)
|
||||
}
|
||||
|
||||
func (transport *Transport) proxySecretRequest(request *http.Request, unversionedPath string) (*http.Response, error) {
|
||||
@@ -358,9 +382,10 @@ func (transport *Transport) proxySecretRequest(request *http.Request, unversione
|
||||
// Assume /secrets/{id}
|
||||
secretID := path.Base(requestPath)
|
||||
|
||||
if request.Method == http.MethodGet {
|
||||
switch request.Method {
|
||||
case http.MethodGet:
|
||||
return transport.rewriteOperation(request, transport.secretInspectOperation)
|
||||
} else if request.Method == http.MethodDelete {
|
||||
case http.MethodDelete:
|
||||
return transport.executeGenericResourceDeletionOperation(request, secretID, secretID, portainer.SecretResourceControl)
|
||||
}
|
||||
|
||||
@@ -413,7 +438,6 @@ func (transport *Transport) proxyBuildRequest(request *http.Request, _ string) (
|
||||
|
||||
func (transport *Transport) updateDefaultGitBranch(request *http.Request) error {
|
||||
remote := request.URL.Query().Get("remote")
|
||||
|
||||
if !strings.HasSuffix(remote, ".git") {
|
||||
return nil
|
||||
}
|
||||
@@ -549,32 +573,101 @@ func (transport *Transport) restrictedResourceOperation(request *http.Request, r
|
||||
}
|
||||
|
||||
resourceControl := authorization.GetResourceControlByResourceIDAndType(resourceID, resourceType, resourceControls)
|
||||
if resourceControl == nil {
|
||||
agentTargetHeader := request.Header.Get(portainer.PortainerAgentTargetHeader)
|
||||
|
||||
if dockerResourceID == "" {
|
||||
dockerResourceID = resourceID
|
||||
}
|
||||
|
||||
// This resource was created outside of portainer,
|
||||
// is part of a Docker service or part of a Docker Swarm/Compose stack.
|
||||
inheritedResourceControl, err := transport.getInheritedResourceControlFromServiceOrStack(dockerResourceID, agentTargetHeader, resourceType, resourceControls)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if inheritedResourceControl == nil || !authorization.UserCanAccessResource(tokenData.ID, userTeamIDs, inheritedResourceControl) {
|
||||
if resourceControl != nil {
|
||||
if !authorization.UserCanAccessResource(tokenData.ID, userTeamIDs, resourceControl) {
|
||||
return utils.WriteAccessDeniedResponse()
|
||||
}
|
||||
return transport.executeDockerRequest(request)
|
||||
}
|
||||
|
||||
if resourceControl != nil && !authorization.UserCanAccessResource(tokenData.ID, userTeamIDs, resourceControl) {
|
||||
client, err := transport.dockerClientFactory.CreateClient(transport.endpoint, request.Header.Get(portainer.PortainerAgentTargetHeader), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// the resourceID may be the resource name (as it's a valid proxy call to use the name and not the UUID)
|
||||
// so get the real resource ID and retry with it
|
||||
resourceID, err = getRealResourceID(client, resourceType, resourceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resourceControl = authorization.GetResourceControlByResourceIDAndType(resourceID, resourceType, resourceControls)
|
||||
if resourceControl != nil {
|
||||
if !authorization.UserCanAccessResource(tokenData.ID, userTeamIDs, resourceControl) {
|
||||
return utils.WriteAccessDeniedResponse()
|
||||
}
|
||||
return transport.executeDockerRequest(request)
|
||||
}
|
||||
|
||||
// If we still can't find the RC by provided ID or "real" (docker-extracted) ID
|
||||
// it means this resource was created outside of portainer,
|
||||
// is part of a Docker service or part of a Docker Swarm/Compose stack.
|
||||
if dockerResourceID == "" {
|
||||
dockerResourceID = resourceID
|
||||
}
|
||||
inheritedResourceControl, err := transport.getInheritedResourceControlFromServiceOrStack(client, dockerResourceID, resourceType, resourceControls)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if inheritedResourceControl == nil || !authorization.UserCanAccessResource(tokenData.ID, userTeamIDs, inheritedResourceControl) {
|
||||
return utils.WriteAccessDeniedResponse()
|
||||
}
|
||||
|
||||
return transport.executeDockerRequest(request)
|
||||
}
|
||||
|
||||
func getRealResourceID(client *dockerclient.Client, resourceType portainer.ResourceControlType, resourceId string) (string, error) {
|
||||
switch resourceType {
|
||||
case portainer.NetworkResourceControl:
|
||||
network, err := client.NetworkInspect(context.Background(), resourceId, network.InspectOptions{})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return network.ID, nil
|
||||
|
||||
case portainer.ContainerResourceControl:
|
||||
container, err := client.ContainerInspect(context.Background(), resourceId)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return container.ID, nil
|
||||
|
||||
case portainer.VolumeResourceControl:
|
||||
// volumes don't have an UUID and their UACresourceID has a particular construct that makes them unique
|
||||
// e.g. fmt.Sprintf("%s_%s", volumeName, dockerID)
|
||||
// see transport.getVolumeResourceID() / FetchDockerID()
|
||||
// FetchDockerID fetches info.Swarm.Cluster.ID if environment(endpoint) is swarm and info.ID otherwise
|
||||
// So: return empty ID but without error
|
||||
return "", nil
|
||||
|
||||
case portainer.ServiceResourceControl:
|
||||
service, _, err := client.ServiceInspectWithRaw(context.Background(), resourceId, swarm.ServiceInspectOptions{})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return service.ID, nil
|
||||
|
||||
case portainer.ConfigResourceControl:
|
||||
config, _, err := client.ConfigInspectWithRaw(context.Background(), resourceId)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return config.ID, nil
|
||||
|
||||
case portainer.SecretResourceControl:
|
||||
secret, _, err := client.SecretInspectWithRaw(context.Background(), resourceId)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return secret.ID, nil
|
||||
|
||||
}
|
||||
return "", fmt.Errorf("Unknown resource type %v", resourceType)
|
||||
}
|
||||
|
||||
// rewriteOperationWithLabelFiltering will create a new operation context with data that will be used
|
||||
// to decorate the original request's response as well as retrieve all the black listed labels
|
||||
// to filter the resources.
|
||||
|
||||
@@ -6,9 +6,19 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/docker/docker/api/types/network"
|
||||
"github.com/docker/docker/api/types/swarm"
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
"github.com/portainer/portainer/api/internal/authorization"
|
||||
"github.com/portainer/portainer/api/internal/testhelpers"
|
||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||
"github.com/segmentio/encoding/json"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTransport_updateDefaultGitBranch(t *testing.T) {
|
||||
@@ -21,7 +31,6 @@ func TestTransport_updateDefaultGitBranch(t *testing.T) {
|
||||
}
|
||||
|
||||
commitId := "my-latest-commit-id"
|
||||
|
||||
defaultFields := fields{
|
||||
gitService: testhelpers.NewGitService(nil, commitId),
|
||||
}
|
||||
@@ -67,3 +76,332 @@ func TestTransport_updateDefaultGitBranch(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RoutesDefinition maps a {method, path} pair to the payload the mock Docker
// API server serves (JSON-encoded) for that route.
type RoutesDefinition map[[2]string]any
|
||||
|
||||
func mockDockerAPIServer(t *testing.T, routes RoutesDefinition) (*httptest.Server, string) {
|
||||
version := "1.51"
|
||||
|
||||
v := func(path string) string {
|
||||
return "/v" + version + path
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodHead && r.URL.Path == "/_ping" {
|
||||
w.Header().Add("Api-Version", version)
|
||||
_, err := w.Write([]byte{})
|
||||
require.Nil(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
for defs, rValue := range routes {
|
||||
method, path := defs[0], defs[1]
|
||||
if r.Method == method && r.URL.Path == v(path) {
|
||||
require.Nil(t, response.JSON(w, rValue))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
http.NotFound(w, r)
|
||||
}))
|
||||
require.NotNil(t, srv)
|
||||
|
||||
return srv, version
|
||||
}
|
||||
|
||||
// TestTransport_getRealResourceID verifies that getRealResourceID resolves
// the canonical Docker ID for every supported resource control type, whether
// the resource is referenced by ID or by name, against a mocked Docker API.
func TestTransport_getRealResourceID(t *testing.T) {
	srv, _ := mockDockerAPIServer(t, RoutesDefinition{
		{http.MethodGet, "/networks"}: []network.Summary{{ID: "16e37c629e88694663791dc738fd37affb908d7b85ce00a20680675d10554fd4", Name: "mynetwork"}},
		{http.MethodGet, "/networks/mynetwork"}: network.Inspect{ID: "16e37c629e88694663791dc738fd37affb908d7b85ce00a20680675d10554fd4", Name: "mynetwork"},
		{http.MethodGet, "/networks/16e37c629e88694663791dc738fd37affb908d7b85ce00a20680675d10554fd4"}: network.Inspect{ID: "16e37c629e88694663791dc738fd37affb908d7b85ce00a20680675d10554fd4", Name: "mynetwork"},
		{http.MethodGet, "/containers/mycontainer/json"}: container.InspectResponse{ContainerJSONBase: &container.ContainerJSONBase{ID: "545fc03ed1fd5008c3bfa2441209ff024e21e396acbeb58b2355930ad1295aa6", Name: "mycontainer"}},
		{http.MethodGet, "/containers/545fc03ed1fd5008c3bfa2441209ff024e21e396acbeb58b2355930ad1295aa6/json"}: container.InspectResponse{ContainerJSONBase: &container.ContainerJSONBase{ID: "545fc03ed1fd5008c3bfa2441209ff024e21e396acbeb58b2355930ad1295aa6", Name: "mycontainer"}},
		{http.MethodGet, "/services/myservice"}: swarm.Service{ID: "ibt43uf5awhg06bxp8rkd7bhi", Spec: swarm.ServiceSpec{Annotations: swarm.Annotations{Name: "myservice"}}},
		{http.MethodGet, "/services/ibt43uf5awhg06bxp8rkd7bhi"}: swarm.Service{ID: "ibt43uf5awhg06bxp8rkd7bhi", Spec: swarm.ServiceSpec{Annotations: swarm.Annotations{Name: "myservice"}}},
		{http.MethodGet, "/configs/myconfig"}: swarm.Config{ID: "3mlqqza0k413ecebk0mfa11em", Spec: swarm.ConfigSpec{Annotations: swarm.Annotations{Name: "myconfig"}}},
		{http.MethodGet, "/configs/3mlqqza0k413ecebk0mfa11em"}: swarm.Config{ID: "3mlqqza0k413ecebk0mfa11em", Spec: swarm.ConfigSpec{Annotations: swarm.Annotations{Name: "myconfig"}}},
		{http.MethodGet, "/secrets/mysecret"}: swarm.Secret{ID: "v9i7o4ivg33u4z3jfyxto162d", Spec: swarm.SecretSpec{Annotations: swarm.Annotations{Name: "mysecret"}}},
		{http.MethodGet, "/secrets/v9i7o4ivg33u4z3jfyxto162d"}: swarm.Secret{ID: "v9i7o4ivg33u4z3jfyxto162d", Spec: swarm.SecretSpec{Annotations: swarm.Annotations{Name: "mysecret"}}},
	})
	defer srv.Close()

	transport := &Transport{
		endpoint: &portainer.Endpoint{URL: srv.URL},
	}

	// NOTE(review): this relies on the zero-value dockerClientFactory being
	// usable for CreateClient — confirm against the factory implementation.
	client, err := transport.dockerClientFactory.CreateClient(transport.endpoint, "", nil)
	require.NoError(t, err)
	require.NotNil(t, client)

	// test resolves a resource first by its ID, then by its name, and finally
	// probes an unknown name of the same type; errOnUnknown states whether the
	// unknown lookup is expected to fail for that resource type.
	test := func(rctype portainer.ResourceControlType, name string, id string, errOnUnknown bool) {
		// by id
		got, err := getRealResourceID(client, rctype, id)
		require.NoError(t, err)
		require.Equal(t, id, got)

		// by name
		got, err = getRealResourceID(client, rctype, name)
		require.NoError(t, err)
		require.Equal(t, id, got)

		// unknown for this type
		_, err = getRealResourceID(client, rctype, "unknown")
		if errOnUnknown {
			require.Error(t, err)
		} else {
			require.NoError(t, err)
		}
	}

	test(portainer.NetworkResourceControl, "mynetwork", "16e37c629e88694663791dc738fd37affb908d7b85ce00a20680675d10554fd4", true)
	test(portainer.ContainerResourceControl, "mycontainer", "545fc03ed1fd5008c3bfa2441209ff024e21e396acbeb58b2355930ad1295aa6", true)
	// volumes always resolve to an empty ID without error (special-cased in getRealResourceID)
	test(portainer.VolumeResourceControl, "anything", "", false)
	test(portainer.ServiceResourceControl, "myservice", "ibt43uf5awhg06bxp8rkd7bhi", true)
	test(portainer.ConfigResourceControl, "myconfig", "3mlqqza0k413ecebk0mfa11em", true)
	test(portainer.SecretResourceControl, "mysecret", "v9i7o4ivg33u4z3jfyxto162d", true)

	// validate that other types are not supported
	_, err = getRealResourceID(client, portainer.ContainerGroupResourceControl, "")
	require.Error(t, err)
}
|
||||
|
||||
// TestTransport_proxyNetworkRequest checks the access-control behavior of the
// network proxy for three personas against a mocked Docker API: an admin, a
// standard user (std1) who owns a private resource control on "mynetwork",
// and a standard user (std2) with no access to it.
func TestTransport_proxyNetworkRequest(t *testing.T) {
	admin := portainer.User{ID: 1, Username: "admin", Role: portainer.AdministratorRole}
	std1 := portainer.User{ID: 2, Username: "std1", Role: portainer.StandardUserRole}
	std2 := portainer.User{ID: 3, Username: "std2", Role: portainer.StandardUserRole}

	_, ds := datastore.MustNewTestStore(t, true, false)

	// Seed the store: the three users, one endpoint std1 can access, and a
	// private resource control giving std1 ownership of the mock network.
	require.NoError(t, ds.UpdateTx(func(tx dataservices.DataStoreTx) error {
		require.NoError(t, tx.User().Create(&admin))
		require.NoError(t, tx.User().Create(&std1))
		require.NoError(t, tx.User().Create(&std2))
		require.NoError(t, tx.Endpoint().Create(&portainer.Endpoint{ID: 1, Name: "env",
			UserAccessPolicies: portainer.UserAccessPolicies{std1.ID: portainer.AccessPolicy{RoleID: 1}},
		}))

		require.NoError(t, tx.ResourceControl().Create(authorization.NewPrivateResourceControl("16e37c629e88694663791dc738fd37affb908d7b85ce00a20680675d10554fd4", portainer.NetworkResourceControl, std1.ID)))

		return nil
	}))

	srv, version := mockDockerAPIServer(t, RoutesDefinition{
		{http.MethodGet, "/networks"}: []network.Summary{{ID: "16e37c629e88694663791dc738fd37affb908d7b85ce00a20680675d10554fd4", Name: "mynetwork"}},
		{http.MethodGet, "/networks/mynetwork"}: network.Inspect{ID: "16e37c629e88694663791dc738fd37affb908d7b85ce00a20680675d10554fd4", Name: "mynetwork"},
		{http.MethodGet, "/networks/16e37c629e88694663791dc738fd37affb908d7b85ce00a20680675d10554fd4"}: network.Inspect{ID: "16e37c629e88694663791dc738fd37affb908d7b85ce00a20680675d10554fd4", Name: "mynetwork"},
		{http.MethodPost, "/networks/mynetwork/connect"}: struct{}{},
		{http.MethodPost, "/networks/16e37c629e88694663791dc738fd37affb908d7b85ce00a20680675d10554fd4/connect"}: struct{}{},
		{http.MethodPost, "/networks/mynetwork/disconnect"}: struct{}{},
		{http.MethodPost, "/networks/16e37c629e88694663791dc738fd37affb908d7b85ce00a20680675d10554fd4/disconnect"}: struct{}{},
		{http.MethodDelete, "/networks/mynetwork"}: struct{}{},
		{http.MethodDelete, "/networks/16e37c629e88694663791dc738fd37affb908d7b85ce00a20680675d10554fd4"}: struct{}{},
		{http.MethodPost, "/networks/create"}: network.CreateResponse{ID: "16e37c629e88694663791dc738fd37affb908d7b85ce00a20680675d10554fd4"},
		{http.MethodPost, "/networks/prune"}: struct{}{},
	})
	defer srv.Close()

	transport := &Transport{
		endpoint:      &portainer.Endpoint{URL: srv.URL},
		dataStore:     ds,
		HTTPTransport: &http.Transport{},
	}

	// test issues a request as the given user's token and runs it through the
	// proxy under test.
	test := func(method string, url string, token portainer.TokenData) (*http.Response, error) {
		req := httptest.NewRequest(method, srv.URL+"/v"+version+url, nil)
		req = req.WithContext(security.StoreTokenData(req, &token))
		require.NotNil(t, req)

		return transport.proxyNetworkRequest(req, url)
	}

	adminToken := portainer.TokenData{ID: admin.ID, Username: admin.Username, Role: admin.Role}
	std1Token := portainer.TokenData{ID: std1.ID, Username: std1.Username, Role: std1.Role}
	std2Token := portainer.TokenData{ID: std2.ID, Username: std2.Username, Role: std2.Role}

	// Listing networks: admin sees the network.
	{
		r, err := test(http.MethodGet, "/networks", adminToken)
		require.NoError(t, err)
		require.NotNil(t, r)
		require.Equal(t, http.StatusOK, r.StatusCode)
		var resp []network.Summary
		require.NoError(t, json.NewDecoder(r.Body).Decode(&resp))
		require.Equal(t, 1, len(resp))
		require.NoError(t, r.Body.Close())
	}

	// Listing networks: the owner (std1) sees the network.
	{
		r, err := test(http.MethodGet, "/networks", std1Token)
		require.NoError(t, err)
		require.NotNil(t, r)
		require.Equal(t, http.StatusOK, r.StatusCode)
		var resp []network.Summary
		require.NoError(t, json.NewDecoder(r.Body).Decode(&resp))
		require.Equal(t, 1, len(resp))
		require.NoError(t, r.Body.Close())
	}

	// Listing networks: std2 gets a 200 but the network is filtered out.
	{
		r, err := test(http.MethodGet, "/networks", std2Token)
		require.NoError(t, err)
		require.NotNil(t, r)
		require.Equal(t, http.StatusOK, r.StatusCode)
		var resp []network.Summary
		require.NoError(t, json.NewDecoder(r.Body).Decode(&resp))
		require.Equal(t, 0, len(resp))
		require.NoError(t, r.Body.Close())
	}

	// Inspecting by name: admin and owner succeed, std2 is denied.
	{
		r, err := test(http.MethodGet, "/networks/mynetwork", adminToken)
		require.NoError(t, err)
		require.NotNil(t, r)
		require.Equal(t, http.StatusOK, r.StatusCode)
		require.NoError(t, r.Body.Close())
	}

	{
		r, err := test(http.MethodGet, "/networks/mynetwork", std1Token)
		require.NoError(t, err)
		require.NotNil(t, r)
		require.Equal(t, http.StatusOK, r.StatusCode)
		require.NoError(t, r.Body.Close())
	}

	{
		r, err := test(http.MethodGet, "/networks/mynetwork", std2Token)
		require.NoError(t, err)
		require.NotNil(t, r)
		require.Equal(t, http.StatusForbidden, r.StatusCode)
		require.NoError(t, r.Body.Close())
	}

	// An unknown network passes through to the Docker API's 404.
	{
		r, err := test(http.MethodGet, "/networks/unknown", adminToken)
		require.NoError(t, err)
		require.NotNil(t, r)
		require.Equal(t, http.StatusNotFound, r.StatusCode)
		require.NoError(t, r.Body.Close())
	}

	// Connect: admin and owner succeed, std2 is denied.
	{
		r, err := test(http.MethodPost, "/networks/mynetwork/connect", adminToken)
		require.NoError(t, err)
		require.NotNil(t, r)
		require.Equal(t, http.StatusOK, r.StatusCode)
		require.NoError(t, r.Body.Close())
	}

	{
		r, err := test(http.MethodPost, "/networks/mynetwork/connect", std1Token)
		require.NoError(t, err)
		require.NotNil(t, r)
		require.NoError(t, r.Body.Close())
		require.Equal(t, http.StatusOK, r.StatusCode)
	}

	{
		r, err := test(http.MethodPost, "/networks/mynetwork/connect", std2Token)
		require.NoError(t, err)
		require.NotNil(t, r)
		require.NoError(t, r.Body.Close())
		require.Equal(t, http.StatusForbidden, r.StatusCode)
	}

	// Disconnect: same access pattern as connect.
	{
		r, err := test(http.MethodPost, "/networks/mynetwork/disconnect", adminToken)
		require.NoError(t, err)
		require.NotNil(t, r)
		require.Equal(t, http.StatusOK, r.StatusCode)
		require.NoError(t, r.Body.Close())
	}

	{
		r, err := test(http.MethodPost, "/networks/mynetwork/disconnect", std1Token)
		require.NoError(t, err)
		require.NotNil(t, r)
		require.Equal(t, http.StatusOK, r.StatusCode)
		require.NoError(t, r.Body.Close())
	}

	{
		r, err := test(http.MethodPost, "/networks/mynetwork/disconnect", std2Token)
		require.NoError(t, err)
		require.NotNil(t, r)
		require.Equal(t, http.StatusForbidden, r.StatusCode)
		require.NoError(t, r.Body.Close())
	}

	// Delete: same access pattern as connect.
	{
		r, err := test(http.MethodDelete, "/networks/mynetwork", adminToken)
		require.NoError(t, err)
		require.NotNil(t, r)
		require.Equal(t, http.StatusOK, r.StatusCode)
		require.NoError(t, r.Body.Close())
	}

	{
		r, err := test(http.MethodDelete, "/networks/mynetwork", std1Token)
		require.NoError(t, err)
		require.NotNil(t, r)
		require.Equal(t, http.StatusOK, r.StatusCode)
		require.NoError(t, r.Body.Close())
	}

	{
		r, err := test(http.MethodDelete, "/networks/mynetwork", std2Token)
		require.NoError(t, err)
		require.NotNil(t, r)
		require.Equal(t, http.StatusForbidden, r.StatusCode)
		require.NoError(t, r.Body.Close())
	}

	// Create is allowed for every user.
	{
		r, err := test(http.MethodPost, "/networks/create", adminToken)
		require.NoError(t, err)
		require.NotNil(t, r)
		require.Equal(t, http.StatusOK, r.StatusCode)
		require.NoError(t, r.Body.Close())
	}

	{
		r, err := test(http.MethodPost, "/networks/create", std1Token)
		require.NoError(t, err)
		require.NotNil(t, r)
		require.Equal(t, http.StatusOK, r.StatusCode)
		require.NoError(t, r.Body.Close())
	}

	{
		r, err := test(http.MethodPost, "/networks/create", std2Token)
		require.NoError(t, err)
		require.NotNil(t, r)
		require.Equal(t, http.StatusOK, r.StatusCode)
		require.NoError(t, r.Body.Close())
	}

	// Prune: only admins may prune; non-admins get an error (nil response).
	{
		r, err := test(http.MethodPost, "/networks/prune", adminToken)
		require.NoError(t, err)
		require.NotNil(t, r)
		require.Equal(t, http.StatusOK, r.StatusCode)
		require.NoError(t, r.Body.Close())
	}

	{
		r, err := test(http.MethodPost, "/networks/prune", std1Token)
		require.Error(t, err)
		require.Nil(t, r)
		// defensive close; unreachable after the Nil assertion above
		if r != nil {
			r.Body.Close()
		}
	}

	{
		r, err := test(http.MethodPost, "/networks/prune", std2Token)
		require.Error(t, err)
		require.Nil(t, r)
		// defensive close; unreachable after the Nil assertion above
		if r != nil {
			r.Body.Close()
		}
	}
}
|
||||
|
||||
@@ -3,10 +3,14 @@ package kubernetes
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewLocalTransport(t *testing.T) {
|
||||
fips.InitFIPS(false)
|
||||
|
||||
transport, err := NewLocalTransport(nil, nil, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
require.True(t, transport.httpTransport.TLSClientConfig.InsecureSkipVerify) //nolint:forbidigo
|
||||
|
||||
@@ -23,6 +23,11 @@ var allowedHeaders = map[string]struct{}{
|
||||
"X-Portainer-Volumename": {},
|
||||
"X-Registry-Auth": {},
|
||||
"X-Stream-Protocol-Version": {},
|
||||
// WebSocket headers those are required for kubectl exec/attach/port-forward operations
|
||||
"Sec-Websocket-Key": {},
|
||||
"Sec-Websocket-Version": {},
|
||||
"Sec-Websocket-Protocol": {},
|
||||
"Sec-Websocket-Extensions": {},
|
||||
}
|
||||
|
||||
// newSingleHostReverseProxyWithHostHeader is based on NewSingleHostReverseProxy
|
||||
|
||||
@@ -63,7 +63,10 @@ type errorResponse struct {
|
||||
|
||||
// WriteAccessDeniedResponse will create a new access denied response
|
||||
func WriteAccessDeniedResponse() (*http.Response, error) {
|
||||
response := &http.Response{}
|
||||
header := http.Header{}
|
||||
header.Add("Content-Type", "application/json")
|
||||
|
||||
response := &http.Response{Header: header}
|
||||
err := RewriteResponse(response, errorResponse{Message: "access denied to resource"}, http.StatusForbidden)
|
||||
|
||||
return response, err
|
||||
|
||||
18
api/http/proxy/factory/utils/response_test.go
Normal file
18
api/http/proxy/factory/utils/response_test.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWriteAccessDeniedResponse(t *testing.T) {
|
||||
r, err := WriteAccessDeniedResponse()
|
||||
require.NoError(t, err)
|
||||
defer r.Body.Close()
|
||||
|
||||
require.NotNil(t, r)
|
||||
require.Equal(t, "application/json", r.Header.Get("content-type"))
|
||||
require.Equal(t, http.StatusForbidden, r.StatusCode)
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package security
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -534,7 +535,7 @@ func MWSecureHeaders(next http.Handler, hsts, csp bool) http.Handler {
|
||||
}
|
||||
|
||||
if csp {
|
||||
w.Header().Set("Content-Security-Policy", "script-src 'self' cdn.matomo.cloud js.hsforms.net; frame-ancestors 'none';")
|
||||
w.Header().Set("Content-Security-Policy", "script-src 'self' cdn.matomo.cloud js.hsforms.net https://www.google.com/recaptcha/, https://www.gstatic.com/recaptcha/; object-src 'none'; frame-ancestors 'none'; frame-src https://www.google.com/recaptcha/ https://www.gstatic.com/recaptcha/")
|
||||
}
|
||||
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
@@ -555,12 +556,9 @@ func (bouncer *RequestBouncer) newRestrictedContextRequest(userID portainer.User
|
||||
return nil, err
|
||||
}
|
||||
|
||||
isTeamLeader := false
|
||||
for _, membership := range memberships {
|
||||
if membership.Role == portainer.TeamLeader {
|
||||
isTeamLeader = true
|
||||
}
|
||||
}
|
||||
isTeamLeader := slices.ContainsFunc(memberships, func(m portainer.TeamMembership) bool {
|
||||
return m.Role == portainer.TeamLeader
|
||||
})
|
||||
|
||||
return &RestrictedRequestContext{
|
||||
IsAdmin: false,
|
||||
|
||||
@@ -111,7 +111,7 @@ func (service *Service) PersistEdgeStack(
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.EndpointRelation().AddEndpointRelationsForEdgeStack(relatedEndpointIds, stack.ID); err != nil {
|
||||
if err := tx.EndpointRelation().AddEndpointRelationsForEdgeStack(relatedEndpointIds, stack); err != nil {
|
||||
return nil, fmt.Errorf("unable to add endpoint relations: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -113,7 +113,7 @@ type datastoreOption = func(d *testDatastore)
|
||||
// NewDatastore creates new instance of testDatastore.
|
||||
// Will apply options before returning, opts will be applied from left to right.
|
||||
func NewDatastore(options ...datastoreOption) *testDatastore {
|
||||
conn, _ := database.NewDatabase("boltdb", "", nil)
|
||||
conn, _ := database.NewDatabase("boltdb", "", nil, false)
|
||||
d := testDatastore{connection: conn}
|
||||
|
||||
for _, o := range options {
|
||||
@@ -230,11 +230,11 @@ func (s *stubEndpointRelationService) UpdateEndpointRelation(ID portainer.Endpoi
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubEndpointRelationService) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error {
|
||||
func (s *stubEndpointRelationService) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStack *portainer.EdgeStack) error {
|
||||
for _, endpointID := range endpointIDs {
|
||||
for i, r := range s.relations {
|
||||
if r.EndpointID == endpointID {
|
||||
s.relations[i].EdgeStacks[edgeStackID] = true
|
||||
s.relations[i].EdgeStacks[edgeStack.ID] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -460,3 +460,39 @@ func WithStacks(stacks []portainer.Stack) datastoreOption {
|
||||
d.stack = &stubStacksService{stacks: stacks}
|
||||
}
|
||||
}
|
||||
|
||||
// stubPendingActionService is an in-memory test double for the pending-actions
// dataservice. The embedded interface satisfies the full contract; methods not
// overridden fall through to the embedded (nil) value.
type stubPendingActionService struct {
	// actions is the backing list served by ReadAll and mutated by Delete.
	actions []portainer.PendingAction
	dataservices.PendingActionsService
}
|
||||
|
||||
func WithPendingActions(pendingActions []portainer.PendingAction) datastoreOption {
|
||||
return func(d *testDatastore) {
|
||||
d.pendingActionsService = &stubPendingActionService{
|
||||
actions: pendingActions,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *stubPendingActionService) ReadAll(predicates ...func(portainer.PendingAction) bool) ([]portainer.PendingAction, error) {
|
||||
filtered := s.actions
|
||||
|
||||
for _, predicate := range predicates {
|
||||
filtered = slicesx.Filter(filtered, predicate)
|
||||
}
|
||||
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
func (s *stubPendingActionService) Delete(ID portainer.PendingActionID) error {
|
||||
actions := []portainer.PendingAction{}
|
||||
|
||||
for _, action := range s.actions {
|
||||
if action.ID != ID {
|
||||
actions = append(actions, action)
|
||||
}
|
||||
}
|
||||
s.actions = actions
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -145,21 +145,33 @@ func (kcl *KubeClient) GetNonAdminNamespaces(userID int, teamIDs []int, isRestri
|
||||
}
|
||||
|
||||
// GetIsKubeAdmin retrieves true if client is admin
|
||||
func (client *KubeClient) GetIsKubeAdmin() bool {
|
||||
return client.IsKubeAdmin
|
||||
func (kcl *KubeClient) GetIsKubeAdmin() bool {
|
||||
kcl.mu.Lock()
|
||||
defer kcl.mu.Unlock()
|
||||
|
||||
return kcl.isKubeAdmin
|
||||
}
|
||||
|
||||
// UpdateIsKubeAdmin sets whether the kube client is admin
|
||||
func (client *KubeClient) SetIsKubeAdmin(isKubeAdmin bool) {
|
||||
client.IsKubeAdmin = isKubeAdmin
|
||||
func (kcl *KubeClient) SetIsKubeAdmin(isKubeAdmin bool) {
|
||||
kcl.mu.Lock()
|
||||
defer kcl.mu.Unlock()
|
||||
|
||||
kcl.isKubeAdmin = isKubeAdmin
|
||||
}
|
||||
|
||||
// GetClientNonAdminNamespaces retrieves non-admin namespaces
|
||||
func (client *KubeClient) GetClientNonAdminNamespaces() []string {
|
||||
return client.NonAdminNamespaces
|
||||
func (kcl *KubeClient) GetClientNonAdminNamespaces() []string {
|
||||
kcl.mu.Lock()
|
||||
defer kcl.mu.Unlock()
|
||||
|
||||
return kcl.nonAdminNamespaces
|
||||
}
|
||||
|
||||
// UpdateClientNonAdminNamespaces sets the client non admin namespace list
|
||||
func (client *KubeClient) SetClientNonAdminNamespaces(nonAdminNamespaces []string) {
|
||||
client.NonAdminNamespaces = nonAdminNamespaces
|
||||
func (kcl *KubeClient) SetClientNonAdminNamespaces(nonAdminNamespaces []string) {
|
||||
kcl.mu.Lock()
|
||||
defer kcl.mu.Unlock()
|
||||
|
||||
kcl.nonAdminNamespaces = nonAdminNamespaces
|
||||
}
|
||||
|
||||
@@ -5,7 +5,9 @@ import (
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
ktypes "k8s.io/api/core/v1"
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
kfake "k8s.io/client-go/kubernetes/fake"
|
||||
@@ -65,3 +67,27 @@ func Test_NamespaceAccessPoliciesDeleteNamespace_updatesPortainerConfig_whenConf
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKubeAdmin(t *testing.T) {
|
||||
kcl := &KubeClient{}
|
||||
require.False(t, kcl.GetIsKubeAdmin())
|
||||
|
||||
kcl.SetIsKubeAdmin(true)
|
||||
require.True(t, kcl.GetIsKubeAdmin())
|
||||
|
||||
kcl.SetIsKubeAdmin(false)
|
||||
require.False(t, kcl.GetIsKubeAdmin())
|
||||
}
|
||||
|
||||
func TestClientNonAdminNamespaces(t *testing.T) {
|
||||
kcl := &KubeClient{}
|
||||
|
||||
require.Empty(t, kcl.GetClientNonAdminNamespaces())
|
||||
|
||||
nss := []string{"ns1", "ns2"}
|
||||
kcl.SetClientNonAdminNamespaces(nss)
|
||||
require.Equal(t, nss, kcl.GetClientNonAdminNamespaces())
|
||||
|
||||
kcl.SetClientNonAdminNamespaces([]string{})
|
||||
require.Empty(t, kcl.GetClientNonAdminNamespaces())
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ type PortainerApplicationResources struct {
|
||||
// if the user is an admin, all namespaces in the current k8s environment(endpoint) are fetched using the fetchApplications function.
|
||||
// otherwise, namespaces the non-admin user has access to will be used to filter the applications based on the allowed namespaces.
|
||||
func (kcl *KubeClient) GetApplications(namespace, nodeName string) ([]models.K8sApplication, error) {
|
||||
if kcl.IsKubeAdmin {
|
||||
if kcl.GetIsKubeAdmin() {
|
||||
return kcl.fetchApplications(namespace, nodeName)
|
||||
}
|
||||
|
||||
@@ -64,9 +64,13 @@ func (kcl *KubeClient) fetchApplications(namespace, nodeName string) ([]models.K
|
||||
// fetchApplicationsForNonAdmin fetches the applications in the namespaces the user has access to.
|
||||
// This function is called when the user is not an admin.
|
||||
func (kcl *KubeClient) fetchApplicationsForNonAdmin(namespace, nodeName string) ([]models.K8sApplication, error) {
|
||||
log.Debug().Msgf("Fetching applications for non-admin user: %v", kcl.NonAdminNamespaces)
|
||||
nonAdminNamespaces := kcl.GetClientNonAdminNamespaces()
|
||||
|
||||
if len(kcl.NonAdminNamespaces) == 0 {
|
||||
log.Debug().
|
||||
Strs("non_admin_namespaces", nonAdminNamespaces).
|
||||
Msg("fetching applications for non-admin user")
|
||||
|
||||
if len(nonAdminNamespaces) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -269,7 +273,8 @@ func populateApplicationFromDeployment(application *models.K8sApplication, deplo
|
||||
application.RunningPodsCount = int(deployment.Status.ReadyReplicas)
|
||||
application.DeploymentType = "Replicated"
|
||||
application.Metadata = &models.Metadata{
|
||||
Labels: deployment.Labels,
|
||||
Labels: deployment.Labels,
|
||||
Annotations: deployment.Annotations,
|
||||
}
|
||||
|
||||
// If the deployment has containers, use the first container's image
|
||||
@@ -297,7 +302,8 @@ func populateApplicationFromStatefulSet(application *models.K8sApplication, stat
|
||||
application.RunningPodsCount = int(statefulSet.Status.ReadyReplicas)
|
||||
application.DeploymentType = "Replicated"
|
||||
application.Metadata = &models.Metadata{
|
||||
Labels: statefulSet.Labels,
|
||||
Labels: statefulSet.Labels,
|
||||
Annotations: statefulSet.Annotations,
|
||||
}
|
||||
|
||||
// If the statefulSet has containers, use the first container's image
|
||||
@@ -322,7 +328,8 @@ func populateApplicationFromDaemonSet(application *models.K8sApplication, daemon
|
||||
application.RunningPodsCount = int(daemonSet.Status.NumberReady)
|
||||
application.DeploymentType = "Global"
|
||||
application.Metadata = &models.Metadata{
|
||||
Labels: daemonSet.Labels,
|
||||
Labels: daemonSet.Labels,
|
||||
Annotations: daemonSet.Annotations,
|
||||
}
|
||||
|
||||
if len(daemonSet.Spec.Template.Spec.Containers) > 0 {
|
||||
@@ -351,7 +358,8 @@ func populateApplicationFromPod(application *models.K8sApplication, pod corev1.P
|
||||
application.RunningPodsCount = runningPodsCount
|
||||
application.DeploymentType = string(pod.Status.Phase)
|
||||
application.Metadata = &models.Metadata{
|
||||
Labels: pod.Labels,
|
||||
Labels: pod.Labels,
|
||||
Annotations: pod.Annotations,
|
||||
}
|
||||
|
||||
// If the pod has containers, use the first container's image
|
||||
|
||||
@@ -310,7 +310,7 @@ func TestGetApplications(t *testing.T) {
|
||||
kubeClient := &KubeClient{
|
||||
cli: fakeClient,
|
||||
instanceID: "test-instance",
|
||||
IsKubeAdmin: true,
|
||||
isKubeAdmin: true,
|
||||
}
|
||||
|
||||
// Test cases
|
||||
@@ -385,8 +385,8 @@ func TestGetApplications(t *testing.T) {
|
||||
kubeClient := &KubeClient{
|
||||
cli: fakeClient,
|
||||
instanceID: "test-instance",
|
||||
IsKubeAdmin: false,
|
||||
NonAdminNamespaces: []string{namespace1},
|
||||
isKubeAdmin: false,
|
||||
nonAdminNamespaces: []string{namespace1},
|
||||
}
|
||||
|
||||
// Test that only resources from allowed namespace are returned
|
||||
@@ -445,7 +445,7 @@ func TestGetApplications(t *testing.T) {
|
||||
kubeClient := &KubeClient{
|
||||
cli: fakeClient,
|
||||
instanceID: "test-instance",
|
||||
IsKubeAdmin: true,
|
||||
isKubeAdmin: true,
|
||||
}
|
||||
|
||||
// Test filtering by node name
|
||||
|
||||
@@ -42,8 +42,8 @@ type (
|
||||
cli kubernetes.Interface
|
||||
instanceID string
|
||||
mu sync.Mutex
|
||||
IsKubeAdmin bool
|
||||
NonAdminNamespaces []string
|
||||
isKubeAdmin bool
|
||||
nonAdminNamespaces []string
|
||||
}
|
||||
)
|
||||
|
||||
@@ -147,6 +147,7 @@ func (factory *ClientFactory) GetProxyKubeClient(endpointID, userID string) (*Ku
|
||||
if ok {
|
||||
return client.(*KubeClient), true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
@@ -179,8 +180,8 @@ func (factory *ClientFactory) CreateKubeClientFromKubeConfig(clusterID string, k
|
||||
return &KubeClient{
|
||||
cli: cli,
|
||||
instanceID: factory.instanceID,
|
||||
IsKubeAdmin: IsKubeAdmin,
|
||||
NonAdminNamespaces: NonAdminNamespaces,
|
||||
isKubeAdmin: IsKubeAdmin,
|
||||
nonAdminNamespaces: NonAdminNamespaces,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -193,7 +194,7 @@ func (factory *ClientFactory) createCachedPrivilegedKubeClient(endpoint *portain
|
||||
return &KubeClient{
|
||||
cli: cli,
|
||||
instanceID: factory.instanceID,
|
||||
IsKubeAdmin: true,
|
||||
isKubeAdmin: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -371,6 +372,7 @@ func (factory *ClientFactory) MigrateEndpointIngresses(e *portainer.Endpoint, da
|
||||
log.Error().Err(err).Msgf("Error getting ingresses in environment %d", environment.ID)
|
||||
return err
|
||||
}
|
||||
|
||||
for _, ingress := range ingresses {
|
||||
oldController, ok := ingress.Annotations["ingress.portainer.io/ingress-type"]
|
||||
if !ok {
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
// GetClusterRoles gets all the clusterRoles for at the cluster level in a k8s endpoint.
|
||||
// It returns a list of K8sClusterRole objects.
|
||||
func (kcl *KubeClient) GetClusterRoles() ([]models.K8sClusterRole, error) {
|
||||
if kcl.IsKubeAdmin {
|
||||
if kcl.GetIsKubeAdmin() {
|
||||
return kcl.fetchClusterRoles()
|
||||
}
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
// GetClusterRoleBindings gets all the clusterRoleBindings for at the cluster level in a k8s endpoint.
|
||||
// It returns a list of K8sClusterRoleBinding objects.
|
||||
func (kcl *KubeClient) GetClusterRoleBindings() ([]models.K8sClusterRoleBinding, error) {
|
||||
if kcl.IsKubeAdmin {
|
||||
if kcl.GetIsKubeAdmin() {
|
||||
return kcl.fetchClusterRoleBindings()
|
||||
}
|
||||
|
||||
|
||||
@@ -16,18 +16,23 @@ import (
|
||||
// if the user is an admin, all configMaps in the current k8s environment(endpoint) are fetched using the fetchConfigMaps function.
|
||||
// otherwise, namespaces the non-admin user has access to will be used to filter the configMaps based on the allowed namespaces.
|
||||
func (kcl *KubeClient) GetConfigMaps(namespace string) ([]models.K8sConfigMap, error) {
|
||||
if kcl.IsKubeAdmin {
|
||||
if kcl.GetIsKubeAdmin() {
|
||||
return kcl.fetchConfigMaps(namespace)
|
||||
}
|
||||
|
||||
return kcl.fetchConfigMapsForNonAdmin(namespace)
|
||||
}
|
||||
|
||||
// fetchConfigMapsForNonAdmin fetches the configMaps in the namespaces the user has access to.
|
||||
// This function is called when the user is not an admin.
|
||||
func (kcl *KubeClient) fetchConfigMapsForNonAdmin(namespace string) ([]models.K8sConfigMap, error) {
|
||||
log.Debug().Msgf("Fetching configMaps for non-admin user: %v", kcl.NonAdminNamespaces)
|
||||
nonAdminNamespaces := kcl.GetClientNonAdminNamespaces()
|
||||
|
||||
if len(kcl.NonAdminNamespaces) == 0 {
|
||||
log.Debug().
|
||||
Strs("non_admin_namespaces", nonAdminNamespaces).
|
||||
Msg("fetching configMaps for non-admin user")
|
||||
|
||||
if len(nonAdminNamespaces) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
// If the user is a kube admin, it returns all cronjobs in the namespace
|
||||
// Otherwise, it returns only the cronjobs in the non-admin namespaces
|
||||
func (kcl *KubeClient) GetCronJobs(namespace string) ([]models.K8sCronJob, error) {
|
||||
if kcl.IsKubeAdmin {
|
||||
if kcl.GetIsKubeAdmin() {
|
||||
return kcl.fetchCronJobs(namespace)
|
||||
}
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ func (kcl *KubeClient) TestFetchCronJobs(t *testing.T) {
|
||||
t.Run("admin client can fetch Cron Jobs from all namespaces", func(t *testing.T) {
|
||||
kcl.cli = kfake.NewSimpleClientset()
|
||||
kcl.instanceID = "test"
|
||||
kcl.IsKubeAdmin = true
|
||||
kcl.isKubeAdmin = true
|
||||
|
||||
cronJobs, err := kcl.GetCronJobs("")
|
||||
if err != nil {
|
||||
@@ -31,8 +31,8 @@ func (kcl *KubeClient) TestFetchCronJobs(t *testing.T) {
|
||||
t.Run("non-admin client can fetch Cron Jobs from the default namespace only", func(t *testing.T) {
|
||||
kcl.cli = kfake.NewSimpleClientset()
|
||||
kcl.instanceID = "test"
|
||||
kcl.IsKubeAdmin = false
|
||||
kcl.NonAdminNamespaces = []string{"default"}
|
||||
kcl.isKubeAdmin = false
|
||||
kcl.SetClientNonAdminNamespaces([]string{"default"})
|
||||
|
||||
cronJobs, err := kcl.GetCronJobs("")
|
||||
if err != nil {
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
// If the user is a kube admin, it returns all events in the namespace
|
||||
// Otherwise, it returns only the events in the non-admin namespaces
|
||||
func (kcl *KubeClient) GetEvents(namespace string, resourceId string) ([]models.K8sEvent, error) {
|
||||
if kcl.IsKubeAdmin {
|
||||
if kcl.GetIsKubeAdmin() {
|
||||
return kcl.fetchAllEvents(namespace, resourceId)
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ func (kcl *KubeClient) GetEvents(namespace string, resourceId string) ([]models.
|
||||
// fetchEventsForNonAdmin returns all events in the given namespace and resource
|
||||
// It returns only the events in the non-admin namespaces
|
||||
func (kcl *KubeClient) fetchEventsForNonAdmin(namespace string, resourceId string) ([]models.K8sEvent, error) {
|
||||
if len(kcl.NonAdminNamespaces) == 0 {
|
||||
if len(kcl.GetClientNonAdminNamespaces()) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ func TestGetEvents(t *testing.T) {
|
||||
kcl := &KubeClient{
|
||||
cli: kfake.NewSimpleClientset(),
|
||||
instanceID: "instance",
|
||||
IsKubeAdmin: true,
|
||||
isKubeAdmin: true,
|
||||
}
|
||||
event := corev1.Event{
|
||||
InvolvedObject: corev1.ObjectReference{UID: "resourceId"},
|
||||
@@ -49,8 +49,8 @@ func TestGetEvents(t *testing.T) {
|
||||
kcl := &KubeClient{
|
||||
cli: kfake.NewSimpleClientset(),
|
||||
instanceID: "instance",
|
||||
IsKubeAdmin: false,
|
||||
NonAdminNamespaces: []string{"nonAdmin"},
|
||||
isKubeAdmin: false,
|
||||
nonAdminNamespaces: []string{"nonAdmin"},
|
||||
}
|
||||
event := corev1.Event{
|
||||
InvolvedObject: corev1.ObjectReference{UID: "resourceId"},
|
||||
@@ -81,8 +81,8 @@ func TestGetEvents(t *testing.T) {
|
||||
kcl := &KubeClient{
|
||||
cli: kfake.NewSimpleClientset(),
|
||||
instanceID: "instance",
|
||||
IsKubeAdmin: false,
|
||||
NonAdminNamespaces: []string{"nonAdmin"},
|
||||
isKubeAdmin: false,
|
||||
nonAdminNamespaces: []string{"nonAdmin"},
|
||||
}
|
||||
event := corev1.Event{
|
||||
InvolvedObject: corev1.ObjectReference{UID: "resourceId"},
|
||||
|
||||
@@ -12,6 +12,16 @@ import (
|
||||
utilexec "k8s.io/client-go/util/exec"
|
||||
)
|
||||
|
||||
var (
|
||||
channelProtocolList = []string{
|
||||
"v5.channel.k8s.io",
|
||||
"v4.channel.k8s.io",
|
||||
"v3.channel.k8s.io",
|
||||
"v2.channel.k8s.io",
|
||||
"channel.k8s.io",
|
||||
}
|
||||
)
|
||||
|
||||
// StartExecProcess will start an exec process inside a container located inside a pod inside a specific namespace
|
||||
// using the specified command. The stdin parameter will be bound to the stdin process and the stdout process will write
|
||||
// to the stdout parameter.
|
||||
@@ -45,10 +55,18 @@ func (kcl *KubeClient) StartExecProcess(token string, useAdminToken bool, namesp
|
||||
TTY: true,
|
||||
}, scheme.ParameterCodec)
|
||||
|
||||
exec, err := remotecommand.NewSPDYExecutor(config, "POST", req.URL())
|
||||
exec, err := remotecommand.NewWebSocketExecutorForProtocols(
|
||||
config,
|
||||
"GET", // WebSocket uses GET for the upgrade request
|
||||
req.URL().String(),
|
||||
channelProtocolList...,
|
||||
)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
exec, err = remotecommand.NewSPDYExecutor(config, "POST", req.URL())
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = exec.StreamWithContext(context.TODO(), remotecommand.StreamOptions{
|
||||
|
||||
@@ -87,17 +87,22 @@ func (kcl *KubeClient) GetIngress(namespace, ingressName string) (models.K8sIngr
|
||||
|
||||
// GetIngresses gets all the ingresses for a given namespace in a k8s endpoint.
|
||||
func (kcl *KubeClient) GetIngresses(namespace string) ([]models.K8sIngressInfo, error) {
|
||||
if kcl.IsKubeAdmin {
|
||||
if kcl.GetIsKubeAdmin() {
|
||||
return kcl.fetchIngresses(namespace)
|
||||
}
|
||||
|
||||
return kcl.fetchIngressesForNonAdmin(namespace)
|
||||
}
|
||||
|
||||
// fetchIngressesForNonAdmin gets all the ingresses for non-admin users in a k8s endpoint.
|
||||
func (kcl *KubeClient) fetchIngressesForNonAdmin(namespace string) ([]models.K8sIngressInfo, error) {
|
||||
log.Debug().Msgf("Fetching ingresses for non-admin user: %v", kcl.NonAdminNamespaces)
|
||||
nonAdminNamespaces := kcl.GetClientNonAdminNamespaces()
|
||||
|
||||
if len(kcl.NonAdminNamespaces) == 0 {
|
||||
log.Debug().
|
||||
Strs("non_admin_namespaces", nonAdminNamespaces).
|
||||
Msg("fetching ingresses for non-admin user")
|
||||
|
||||
if len(nonAdminNamespaces) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
||||
15
api/kubernetes/cli/ingress_test.go
Normal file
15
api/kubernetes/cli/ingress_test.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetIngresses(t *testing.T) {
|
||||
kcl := &KubeClient{}
|
||||
|
||||
ingresses, err := kcl.GetIngresses("default")
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, ingresses)
|
||||
}
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
// If the user is a kube admin, it returns all jobs in the namespace
|
||||
// Otherwise, it returns only the jobs in the non-admin namespaces
|
||||
func (kcl *KubeClient) GetJobs(namespace string, includeCronJobChildren bool) ([]models.K8sJob, error) {
|
||||
if kcl.IsKubeAdmin {
|
||||
if kcl.GetIsKubeAdmin() {
|
||||
return kcl.fetchJobs(namespace, includeCronJobChildren)
|
||||
}
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ func (kcl *KubeClient) TestFetchJobs(t *testing.T) {
|
||||
t.Run("admin client can fetch jobs from all namespaces", func(t *testing.T) {
|
||||
kcl.cli = kfake.NewSimpleClientset()
|
||||
kcl.instanceID = "test"
|
||||
kcl.IsKubeAdmin = true
|
||||
kcl.isKubeAdmin = true
|
||||
|
||||
jobs, err := kcl.GetJobs("", false)
|
||||
if err != nil {
|
||||
@@ -34,8 +34,8 @@ func (kcl *KubeClient) TestFetchJobs(t *testing.T) {
|
||||
t.Run("non-admin client can fetch jobs from the default namespace only", func(t *testing.T) {
|
||||
kcl.cli = kfake.NewSimpleClientset()
|
||||
kcl.instanceID = "test"
|
||||
kcl.IsKubeAdmin = false
|
||||
kcl.NonAdminNamespaces = []string{"default"}
|
||||
kcl.isKubeAdmin = false
|
||||
kcl.SetClientNonAdminNamespaces([]string{"default"})
|
||||
|
||||
jobs, err := kcl.GetJobs("", false)
|
||||
if err != nil {
|
||||
|
||||
@@ -40,9 +40,10 @@ func defaultSystemNamespaces() map[string]struct{} {
|
||||
// if the user is an admin, all namespaces in the current k8s environment(endpoint) are fetched using the fetchNamespaces function.
|
||||
// otherwise, namespaces the non-admin user has access to will be used to filter the namespaces based on the allowed namespaces.
|
||||
func (kcl *KubeClient) GetNamespaces() (map[string]portainer.K8sNamespaceInfo, error) {
|
||||
if kcl.IsKubeAdmin {
|
||||
if kcl.GetIsKubeAdmin() {
|
||||
return kcl.fetchNamespaces()
|
||||
}
|
||||
|
||||
return kcl.fetchNamespacesForNonAdmin()
|
||||
}
|
||||
|
||||
@@ -52,7 +53,7 @@ func (kcl *KubeClient) fetchNamespacesForNonAdmin() (map[string]portainer.K8sNam
|
||||
Str("context", "fetchNamespacesForNonAdmin").
|
||||
Msg("Fetching namespaces for non-admin user")
|
||||
|
||||
if len(kcl.NonAdminNamespaces) == 0 {
|
||||
if len(kcl.GetClientNonAdminNamespaces()) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -142,6 +143,7 @@ func (kcl *KubeClient) CreateNamespace(info models.K8sNamespaceDetails) (*corev1
|
||||
Str("context", "CreateNamespace").
|
||||
Str("Namespace", info.Name).
|
||||
Msg("Failed to create the namespace")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -157,7 +159,7 @@ func (kcl *KubeClient) CreateNamespace(info models.K8sNamespaceDetails) (*corev1
|
||||
return namespace, nil
|
||||
}
|
||||
|
||||
// UpdateIngress updates an ingress in a given namespace in a k8s endpoint.
|
||||
// UpdateNamespace updates a namespace in a k8s endpoint.
|
||||
func (kcl *KubeClient) UpdateNamespace(info models.K8sNamespaceDetails) (*corev1.Namespace, error) {
|
||||
portainerLabels := map[string]string{
|
||||
namespaceNameLabel: stackutils.SanitizeLabel(info.Name),
|
||||
@@ -420,8 +422,10 @@ func (kcl *KubeClient) CombineNamespaceWithResourceQuota(namespace portainer.K8s
|
||||
// buildNonAdminNamespacesMap builds a map of non-admin namespaces.
|
||||
// the map is used to filter the namespaces based on the allowed namespaces.
|
||||
func (kcl *KubeClient) buildNonAdminNamespacesMap() map[string]struct{} {
|
||||
nonAdminNamespaceSet := make(map[string]struct{}, len(kcl.NonAdminNamespaces))
|
||||
for _, namespace := range kcl.NonAdminNamespaces {
|
||||
nonAdminNamespaces := kcl.GetClientNonAdminNamespaces()
|
||||
nonAdminNamespaceSet := make(map[string]struct{}, len(nonAdminNamespaces))
|
||||
|
||||
for _, namespace := range nonAdminNamespaces {
|
||||
if !isSystemDefaultNamespace(namespace) {
|
||||
nonAdminNamespaceSet[namespace] = struct{}{}
|
||||
}
|
||||
|
||||
@@ -176,6 +176,7 @@ func Test_ToggleSystemState(t *testing.T) {
|
||||
expectedPolicies := map[string]portainer.K8sNamespaceAccessPolicy{
|
||||
"ns2": {UserAccessPolicies: portainer.UserAccessPolicies{2: {RoleID: 0}}},
|
||||
}
|
||||
|
||||
actualPolicies, err := kcl.GetNamespaceAccessPolicies()
|
||||
assert.NoError(t, err, "failed to fetch policies")
|
||||
assert.Equal(t, expectedPolicies, actualPolicies)
|
||||
|
||||
@@ -46,9 +46,9 @@ func (kcl *KubeClient) GetNodesLimits() (portainer.K8sNodesLimits, error) {
|
||||
|
||||
// GetMaxResourceLimits gets the maximum CPU and Memory limits(unused resources) of all nodes in the current k8s environment(endpoint) connection, minus the accumulated resourcequotas for all namespaces except the one we're editing (skipNamespace)
|
||||
// if skipNamespace is set to "" then all namespaces are considered
|
||||
func (client *KubeClient) GetMaxResourceLimits(skipNamespace string, overCommitEnabled bool, resourceOverCommitPercent int) (portainer.K8sNodeLimits, error) {
|
||||
func (kcl *KubeClient) GetMaxResourceLimits(skipNamespace string, overCommitEnabled bool, resourceOverCommitPercent int) (portainer.K8sNodeLimits, error) {
|
||||
limits := portainer.K8sNodeLimits{}
|
||||
nodes, err := client.cli.CoreV1().Nodes().List(context.TODO(), metav1.ListOptions{})
|
||||
nodes, err := kcl.cli.CoreV1().Nodes().List(context.TODO(), metav1.ListOptions{})
|
||||
if err != nil {
|
||||
return limits, err
|
||||
}
|
||||
@@ -62,7 +62,7 @@ func (client *KubeClient) GetMaxResourceLimits(skipNamespace string, overCommitE
|
||||
limits.Memory = memory / 1000000 // B to MB
|
||||
|
||||
if !overCommitEnabled {
|
||||
namespaces, err := client.cli.CoreV1().Namespaces().List(context.TODO(), metav1.ListOptions{})
|
||||
namespaces, err := kcl.cli.CoreV1().Namespaces().List(context.TODO(), metav1.ListOptions{})
|
||||
if err != nil {
|
||||
return limits, err
|
||||
}
|
||||
@@ -77,7 +77,7 @@ func (client *KubeClient) GetMaxResourceLimits(skipNamespace string, overCommitE
|
||||
}
|
||||
|
||||
// minus accumulated resourcequotas for all namespaces except the one we're editing
|
||||
resourceQuota, err := client.cli.CoreV1().ResourceQuotas(namespace.Name).List(context.TODO(), metav1.ListOptions{})
|
||||
resourceQuota, err := kcl.cli.CoreV1().ResourceQuotas(namespace.Name).List(context.TODO(), metav1.ListOptions{})
|
||||
if err != nil {
|
||||
log.Debug().Msgf("error getting resourcequota for namespace %s: %s", namespace.Name, err)
|
||||
continue // skip it
|
||||
|
||||
@@ -59,6 +59,7 @@ func Test_waitForPodStatus(t *testing.T) {
|
||||
|
||||
ctx, cancelFunc := context.WithTimeout(context.TODO(), 0*time.Second)
|
||||
defer cancelFunc()
|
||||
|
||||
err = k.waitForPodStatus(ctx, v1.PodRunning, podSpec)
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
t.Errorf("waitForPodStatus should throw deadline exceeded error; err=%s", err)
|
||||
|
||||
@@ -15,18 +15,23 @@ import (
|
||||
// if the user is an admin, all resource quotas in all namespaces are fetched.
|
||||
// otherwise, namespaces the non-admin user has access to will be used to filter the resource quotas.
|
||||
func (kcl *KubeClient) GetResourceQuotas(namespace string) (*[]corev1.ResourceQuota, error) {
|
||||
if kcl.IsKubeAdmin {
|
||||
if kcl.GetIsKubeAdmin() {
|
||||
return kcl.fetchResourceQuotas(namespace)
|
||||
}
|
||||
|
||||
return kcl.fetchResourceQuotasForNonAdmin(namespace)
|
||||
}
|
||||
|
||||
// fetchResourceQuotasForNonAdmin gets the resource quotas in the current k8s environment(endpoint) for a non-admin user.
|
||||
// the role of the user must have read access to the resource quotas in the defined namespaces.
|
||||
func (kcl *KubeClient) fetchResourceQuotasForNonAdmin(namespace string) (*[]corev1.ResourceQuota, error) {
|
||||
log.Debug().Msgf("Fetching resource quotas for non-admin user: %v", kcl.NonAdminNamespaces)
|
||||
nonAdminNamespaces := kcl.GetClientNonAdminNamespaces()
|
||||
|
||||
if len(kcl.NonAdminNamespaces) == 0 {
|
||||
log.Debug().
|
||||
Strs("non_admin_namespaces", nonAdminNamespaces).
|
||||
Msg("fetching resource quotas for non-admin user")
|
||||
|
||||
if len(nonAdminNamespaces) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
// GetRoles gets all the roles for either at the cluster level or a given namespace in a k8s endpoint.
|
||||
// It returns a list of K8sRole objects.
|
||||
func (kcl *KubeClient) GetRoles(namespace string) ([]models.K8sRole, error) {
|
||||
if kcl.IsKubeAdmin {
|
||||
if kcl.GetIsKubeAdmin() {
|
||||
return kcl.fetchRoles(namespace)
|
||||
}
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
// GetRoleBindings gets all the roleBindings for either at the cluster level or a given namespace in a k8s endpoint.
|
||||
// It returns a list of K8sRoleBinding objects.
|
||||
func (kcl *KubeClient) GetRoleBindings(namespace string) ([]models.K8sRoleBinding, error) {
|
||||
if kcl.IsKubeAdmin {
|
||||
if kcl.GetIsKubeAdmin() {
|
||||
return kcl.fetchRoleBindings(namespace)
|
||||
}
|
||||
|
||||
|
||||
@@ -23,18 +23,23 @@ const (
|
||||
// if the user is an admin, all secrets in the current k8s environment(endpoint) are fetched using the getSecrets function.
|
||||
// otherwise, namespaces the non-admin user has access to will be used to filter the secrets based on the allowed namespaces.
|
||||
func (kcl *KubeClient) GetSecrets(namespace string) ([]models.K8sSecret, error) {
|
||||
if kcl.IsKubeAdmin {
|
||||
if kcl.GetIsKubeAdmin() {
|
||||
return kcl.getSecrets(namespace)
|
||||
}
|
||||
|
||||
return kcl.getSecretsForNonAdmin(namespace)
|
||||
}
|
||||
|
||||
// getSecretsForNonAdmin fetches the secrets in the namespaces the user has access to.
|
||||
// This function is called when the user is not an admin.
|
||||
func (kcl *KubeClient) getSecretsForNonAdmin(namespace string) ([]models.K8sSecret, error) {
|
||||
log.Debug().Msgf("Fetching secrets for non-admin user: %v", kcl.NonAdminNamespaces)
|
||||
nonAdminNamespaces := kcl.GetClientNonAdminNamespaces()
|
||||
|
||||
if len(kcl.NonAdminNamespaces) == 0 {
|
||||
log.Debug().
|
||||
Strs("non_admin_namespaces", nonAdminNamespaces).
|
||||
Msg("fetching secrets for non-admin user")
|
||||
|
||||
if len(nonAdminNamespaces) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -15,9 +15,10 @@ import (
|
||||
// GetServices gets all the services for either at the cluster level or a given namespace in a k8s endpoint.
|
||||
// It returns a list of K8sServiceInfo objects.
|
||||
func (kcl *KubeClient) GetServices(namespace string) ([]models.K8sServiceInfo, error) {
|
||||
if kcl.IsKubeAdmin {
|
||||
if kcl.GetIsKubeAdmin() {
|
||||
return kcl.fetchServices(namespace)
|
||||
}
|
||||
|
||||
return kcl.fetchServicesForNonAdmin(namespace)
|
||||
}
|
||||
|
||||
@@ -25,9 +26,13 @@ func (kcl *KubeClient) GetServices(namespace string) ([]models.K8sServiceInfo, e
|
||||
// the namespace will be coming from NonAdminNamespaces as non-admin users are restricted to certain namespaces.
|
||||
// it returns a list of K8sServiceInfo objects.
|
||||
func (kcl *KubeClient) fetchServicesForNonAdmin(namespace string) ([]models.K8sServiceInfo, error) {
|
||||
log.Debug().Msgf("Fetching services for non-admin user: %v", kcl.NonAdminNamespaces)
|
||||
nonAdminNamespaces := kcl.GetClientNonAdminNamespaces()
|
||||
|
||||
if len(kcl.NonAdminNamespaces) == 0 {
|
||||
log.Debug().
|
||||
Strs("non_admin_namespaces", nonAdminNamespaces).
|
||||
Msg("fetching services for non-admin user")
|
||||
|
||||
if len(nonAdminNamespaces) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
// GetServiceAccounts gets all the service accounts for either at the cluster level or a given namespace in a k8s endpoint.
|
||||
// It returns a list of K8sServiceAccount objects.
|
||||
func (kcl *KubeClient) GetServiceAccounts(namespace string) ([]models.K8sServiceAccount, error) {
|
||||
if kcl.IsKubeAdmin {
|
||||
if kcl.GetIsKubeAdmin() {
|
||||
return kcl.fetchServiceAccounts(namespace)
|
||||
}
|
||||
|
||||
|
||||
15
api/kubernetes/cli/service_test.go
Normal file
15
api/kubernetes/cli/service_test.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetServices(t *testing.T) {
|
||||
kcl := &KubeClient{}
|
||||
|
||||
services, err := kcl.GetServices("default")
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, services)
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user