Compare commits
256 Commits
yd-develop
...
2.33.6
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
eb63a7ad7c | ||
|
|
76e4054215 | ||
|
|
0a3e13915c | ||
|
|
dbd6e49e5f | ||
|
|
abad58a370 | ||
|
|
4eb1c7b11f | ||
|
|
3afedce570 | ||
|
|
a7b6db72a5 | ||
|
|
9c79d6dc7d | ||
|
|
11f612a501 | ||
|
|
cb8d8fcfd6 | ||
|
|
22bb1e604d | ||
|
|
970b135261 | ||
|
|
a69470ec08 | ||
|
|
ea6f1c97f5 | ||
|
|
6d058987f3 | ||
|
|
6998f05855 | ||
|
|
94d01c58fc | ||
|
|
d98eb77067 | ||
|
|
941e86563a | ||
|
|
f72d6b97d3 | ||
|
|
32926aa8bf | ||
|
|
1849c61c38 | ||
|
|
fd6d74602c | ||
|
|
74b1dd04d1 | ||
|
|
7450501b7a | ||
|
|
dcfe2d9809 | ||
|
|
c21c91632f | ||
|
|
732337615e | ||
|
|
6ea16c0060 | ||
|
|
4e7d4b60a5 | ||
|
|
19e1cc2fbd | ||
|
|
68b9fef3f0 | ||
|
|
1e47df6611 | ||
|
|
405ce8f671 | ||
|
|
e9d31b3b7b | ||
|
|
f97adc94ad | ||
|
|
11d6341765 | ||
|
|
c3cf46b0e0 | ||
|
|
ff746beba1 | ||
|
|
da1672fc17 | ||
|
|
7a9376cbaf | ||
|
|
c0f6410d80 | ||
|
|
4b9ab98fd2 | ||
|
|
3354ee4e4b | ||
|
|
af3c45bea0 | ||
|
|
816a6f9bef | ||
|
|
e86ea22900 | ||
|
|
12b2acbc00 | ||
|
|
4a8b42928e | ||
|
|
2e828b39da | ||
|
|
49c6521c23 | ||
|
|
debf1a742b | ||
|
|
5d3708ec3e | ||
|
|
9320fd4c50 | ||
|
|
974682bd98 | ||
|
|
631f1deb2e | ||
|
|
4169b045fb | ||
|
|
0a2a786aa3 | ||
|
|
808f87206e | ||
|
|
ed6fa82904 | ||
|
|
9fc301110b | ||
|
|
69101ac89a | ||
|
|
69d33dd432 | ||
|
|
389cbf748c | ||
|
|
d01b31f707 | ||
|
|
3ade5cdf19 | ||
|
|
5f6fa4d79f | ||
|
|
3ee20863d6 | ||
|
|
8fe5eaee29 | ||
|
|
208534c9d9 | ||
|
|
3f030394c6 | ||
|
|
6ca0085ec8 | ||
|
|
2cf1649c67 | ||
|
|
64ed988169 | ||
|
|
85b7e881eb | ||
|
|
9325cb2872 | ||
|
|
e39dcc458b | ||
|
|
84b4b30f21 | ||
|
|
6c47598cd9 | ||
|
|
d00d71ecbf | ||
|
|
dc273b2d63 | ||
|
|
497b16e942 | ||
|
|
a472de1919 | ||
|
|
d306d7a983 | ||
|
|
163aa57e5c | ||
|
|
3eab294908 | ||
|
|
da30780ac2 | ||
|
|
ef53354193 | ||
|
|
e9ce3d2213 | ||
|
|
a46db61c4c | ||
|
|
5e271fd4a4 | ||
|
|
6481483074 | ||
|
|
7bcb37c761 | ||
|
|
e7d97d7a2b | ||
|
|
1afae99345 | ||
|
|
bdb2e2f417 | ||
|
|
bba3751268 | ||
|
|
60bc04bc33 | ||
|
|
a4cff13531 | ||
|
|
937456596a | ||
|
|
caf382b64c | ||
|
|
55cc250d2e | ||
|
|
eaa2be017d | ||
|
|
4e4c5ffdb6 | ||
|
|
383bcc4113 | ||
|
|
9f906b7417 | ||
|
|
db2e168540 | ||
|
|
2697d6c5d7 | ||
|
|
b6a6ce9aaf | ||
|
|
89f6a94bd8 | ||
|
|
96f2d69ae5 | ||
|
|
b7e906701a | ||
|
|
150d986179 | ||
|
|
ef10ea2a7d | ||
|
|
3bf84e8b0c | ||
|
|
ea4b334c7e | ||
|
|
4d11aa8655 | ||
|
|
302deb8299 | ||
|
|
0c80b1067d | ||
|
|
0a36d4fbfd | ||
|
|
c20a8b5a68 | ||
|
|
8ffe4e284a | ||
|
|
1332f718ae | ||
|
|
f4df51884c | ||
|
|
ce86129478 | ||
|
|
097b125e3a | ||
|
|
5c6b53922a | ||
|
|
e1b9f23f73 | ||
|
|
e1c480d3c3 | ||
|
|
363a62d885 | ||
|
|
c6ee9a5a52 | ||
|
|
cf5990ccba | ||
|
|
b6f3682a62 | ||
|
|
b43f864511 | ||
|
|
0556ffb4a1 | ||
|
|
303047656e | ||
|
|
8d29b5ae71 | ||
|
|
7d7ae24351 | ||
|
|
97838e614d | ||
|
|
c897baad20 | ||
|
|
d51e9205d9 | ||
|
|
e051c86bb5 | ||
|
|
c2b48cd003 | ||
|
|
a7009eb8d5 | ||
|
|
036b87b649 | ||
|
|
f07a3b1875 | ||
|
|
6e89ccc0ae | ||
|
|
cc67612432 | ||
|
|
17ebe221bb | ||
|
|
1963edda66 | ||
|
|
c9e3717ce3 | ||
|
|
9a85246631 | ||
|
|
75f165d1ff | ||
|
|
eaf0deb2f6 | ||
|
|
a9061e5258 | ||
|
|
caac45b834 | ||
|
|
24ff7a7911 | ||
|
|
b767dcb27e | ||
|
|
731afbee46 | ||
|
|
07dfd981a2 | ||
|
|
32ef208278 | ||
|
|
a80b185e10 | ||
|
|
b96328e098 | ||
|
|
45471ce86d | ||
|
|
1bc91d0c7c | ||
|
|
799325d9f8 | ||
|
|
b540709e03 | ||
|
|
44daab04ac | ||
|
|
ee65223ee7 | ||
|
|
d49fcd8f3e | ||
|
|
4ee349bd6b | ||
|
|
dfa32b6755 | ||
|
|
0b69729173 | ||
|
|
3b313b9308 | ||
|
|
1abdf42f99 | ||
|
|
9fdc535d6b | ||
|
|
b9b734ceda | ||
|
|
3b05505527 | ||
|
|
bc29419c17 | ||
|
|
4d4360b86b | ||
|
|
8cc28761d7 | ||
|
|
24b3499c70 | ||
|
|
4e4fd5a4b4 | ||
|
|
1a3df54c04 | ||
|
|
3edacee59b | ||
|
|
f25d31b92b | ||
|
|
c91c8a6467 | ||
|
|
61d6ac035d | ||
|
|
9a9373dd0f | ||
|
|
e319a7a5ae | ||
|
|
342549b546 | ||
|
|
bbe94f55b6 | ||
|
|
6fcf1893d3 | ||
|
|
01afe34df7 | ||
|
|
be3e8e3332 | ||
|
|
cf31700903 | ||
|
|
66dee6fd06 | ||
|
|
bfa55f8c67 | ||
|
|
5a2318d01f | ||
|
|
7de037029f | ||
|
|
730c1115ce | ||
|
|
2c37f32fa6 | ||
|
|
7aa9f8b1c3 | ||
|
|
c331ada086 | ||
|
|
ebc25e45d3 | ||
|
|
f82921d2a1 | ||
|
|
d68fe42918 | ||
|
|
823f2a7991 | ||
|
|
0ca9321db1 | ||
|
|
46eddbe7b9 | ||
|
|
64c796a8c3 | ||
|
|
264ff5457b | ||
|
|
ad89df4d0d | ||
|
|
0f10b8ba2b | ||
|
|
940bf990f9 | ||
|
|
1b8fbbe7d7 | ||
|
|
f6f07f4690 | ||
|
|
3800249921 | ||
|
|
a5d857d5e7 | ||
|
|
4c1e80ff58 | ||
|
|
7e5db1f55e | ||
|
|
1edc56c0ce | ||
|
|
4066a70ea5 | ||
|
|
a0d36cf87a | ||
|
|
1d12011eb5 | ||
|
|
7c01f84a5c | ||
|
|
81c5f4acc3 | ||
|
|
0ebfe047d1 | ||
|
|
e68bd53e30 | ||
|
|
cdd9851f72 | ||
|
|
995c3ef81b | ||
|
|
0dfde1374d | ||
|
|
34235199dd | ||
|
|
5d1cd670e9 | ||
|
|
1d8ea7b0ee | ||
|
|
4b218553c3 | ||
|
|
a61c1004d3 | ||
|
|
5d1b42b314 | ||
|
|
4b992c6f3e | ||
|
|
38562f9560 | ||
|
|
c01f0271fe | ||
|
|
0296998fae | ||
|
|
a67b917bdd | ||
|
|
2791bd123c | ||
|
|
e1f9b69cd5 | ||
|
|
2c05496962 | ||
|
|
66bcf9223a | ||
|
|
993f69db37 | ||
|
|
58317edb6d | ||
|
|
417891675d | ||
|
|
8b7aef883a | ||
|
|
b5961d79f8 | ||
|
|
0d25f3f430 | ||
|
|
798fa2396a | ||
|
|
28b222fffa |
@@ -17,7 +17,7 @@ plugins:
|
||||
- import
|
||||
|
||||
parserOptions:
|
||||
ecmaVersion: 2018
|
||||
ecmaVersion: latest
|
||||
sourceType: module
|
||||
project: './tsconfig.json'
|
||||
ecmaFeatures:
|
||||
|
||||
45
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
45
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
@@ -2,18 +2,17 @@ name: Bug Report
|
||||
description: Create a report to help us improve.
|
||||
labels: kind/bug,bug/need-confirmation
|
||||
body:
|
||||
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
# Welcome!
|
||||
|
||||
|
||||
The issue tracker is for reporting bugs. If you have an [idea for a new feature](https://github.com/orgs/portainer/discussions/categories/ideas) or a [general question about Portainer](https://github.com/orgs/portainer/discussions/categories/help) please post in our [GitHub Discussions](https://github.com/orgs/portainer/discussions).
|
||||
|
||||
|
||||
You can also ask for help in our [community Slack channel](https://join.slack.com/t/portainer/shared_invite/zt-txh3ljab-52QHTyjCqbe5RibC2lcjKA).
|
||||
|
||||
Please note that we only provide support for current versions of Portainer. You can find a list of supported versions in our [lifecycle policy](https://docs.portainer.io/start/lifecycle).
|
||||
|
||||
|
||||
**DO NOT FILE ISSUES FOR GENERAL SUPPORT QUESTIONS**.
|
||||
|
||||
- type: checkboxes
|
||||
@@ -45,7 +44,7 @@ body:
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Problem Description
|
||||
description: A clear and concise description of what the bug is.
|
||||
description: A clear and concise description of what the bug is.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
@@ -71,7 +70,7 @@ body:
|
||||
1. Go to '...'
|
||||
2. Click on '....'
|
||||
3. Scroll down to '....'
|
||||
4. See error
|
||||
4. See error
|
||||
validations:
|
||||
required: true
|
||||
|
||||
@@ -92,9 +91,29 @@ body:
|
||||
- type: dropdown
|
||||
attributes:
|
||||
label: Portainer version
|
||||
description: We only provide support for current versions of Portainer as per the lifecycle policy linked above. If you are on an older version of Portainer we recommend [upgrading first](https://docs.portainer.io/start/upgrade) in case your bug has already been fixed.
|
||||
description: We only provide support for current versions of Portainer as per the lifecycle policy linked above. If you are on an older version of Portainer we recommend [updating first](https://docs.portainer.io/start/upgrade) in case your bug has already been fixed.
|
||||
multiple: false
|
||||
options:
|
||||
- '2.32.0'
|
||||
- '2.31.3'
|
||||
- '2.31.2'
|
||||
- '2.31.1'
|
||||
- '2.31.0'
|
||||
- '2.30.1'
|
||||
- '2.30.0'
|
||||
- '2.29.2'
|
||||
- '2.29.1'
|
||||
- '2.29.0'
|
||||
- '2.28.1'
|
||||
- '2.28.0'
|
||||
- '2.27.9'
|
||||
- '2.27.8'
|
||||
- '2.27.7'
|
||||
- '2.27.6'
|
||||
- '2.27.5'
|
||||
- '2.27.4'
|
||||
- '2.27.3'
|
||||
- '2.27.2'
|
||||
- '2.27.1'
|
||||
- '2.27.0'
|
||||
- '2.26.1'
|
||||
@@ -111,16 +130,6 @@ body:
|
||||
- '2.21.2'
|
||||
- '2.21.1'
|
||||
- '2.21.0'
|
||||
- '2.20.3'
|
||||
- '2.20.2'
|
||||
- '2.20.1'
|
||||
- '2.20.0'
|
||||
- '2.19.5'
|
||||
- '2.19.4'
|
||||
- '2.19.3'
|
||||
- '2.19.2'
|
||||
- '2.19.1'
|
||||
- '2.19.0'
|
||||
validations:
|
||||
required: true
|
||||
|
||||
@@ -158,7 +167,7 @@ body:
|
||||
- type: input
|
||||
attributes:
|
||||
label: Browser
|
||||
description: |
|
||||
description: |
|
||||
Enter your browser and version. Example: Google Chrome 114.0
|
||||
validations:
|
||||
required: false
|
||||
|
||||
@@ -1,40 +1,59 @@
|
||||
version: "2"
|
||||
linters:
|
||||
# Disable all linters, the defaults don't pass on our code yet
|
||||
disable-all: true
|
||||
|
||||
# Enable these for now
|
||||
default: none
|
||||
enable:
|
||||
- unused
|
||||
- depguard
|
||||
- gosimple
|
||||
- govet
|
||||
- errorlint
|
||||
- bodyclose
|
||||
- copyloopvar
|
||||
- depguard
|
||||
- errorlint
|
||||
- forbidigo
|
||||
- govet
|
||||
- ineffassign
|
||||
- intrange
|
||||
- perfsprint
|
||||
|
||||
linters-settings:
|
||||
depguard:
|
||||
rules:
|
||||
main:
|
||||
deny:
|
||||
- pkg: 'encoding/json'
|
||||
desc: 'use github.com/segmentio/encoding/json'
|
||||
- pkg: 'golang.org/x/exp'
|
||||
desc: 'exp is not allowed'
|
||||
- pkg: 'github.com/portainer/libcrypto'
|
||||
desc: 'use github.com/portainer/portainer/pkg/libcrypto'
|
||||
- pkg: 'github.com/portainer/libhttp'
|
||||
desc: 'use github.com/portainer/portainer/pkg/libhttp'
|
||||
files:
|
||||
- '!**/*_test.go'
|
||||
- '!**/base.go'
|
||||
- '!**/base_tx.go'
|
||||
|
||||
# errorlint is causing a typecheck error for some reason. The go compiler will report these
|
||||
# anyway, so ignore them from the linter
|
||||
issues:
|
||||
exclude-rules:
|
||||
- path: ./
|
||||
linters:
|
||||
- typecheck
|
||||
- staticcheck
|
||||
- unused
|
||||
settings:
|
||||
staticcheck:
|
||||
checks: ["all", "-ST1003", "-ST1005", "-ST1016", "-SA1019", "-QF1003"]
|
||||
depguard:
|
||||
rules:
|
||||
main:
|
||||
files:
|
||||
- '!**/*_test.go'
|
||||
- '!**/base.go'
|
||||
- '!**/base_tx.go'
|
||||
deny:
|
||||
- pkg: encoding/json
|
||||
desc: use github.com/segmentio/encoding/json
|
||||
- pkg: golang.org/x/exp
|
||||
desc: exp is not allowed
|
||||
- pkg: github.com/portainer/libcrypto
|
||||
desc: use github.com/portainer/portainer/pkg/libcrypto
|
||||
- pkg: github.com/portainer/libhttp
|
||||
desc: use github.com/portainer/portainer/pkg/libhttp
|
||||
forbidigo:
|
||||
forbid:
|
||||
- pattern: ^tls\.Config$
|
||||
msg: Use crypto.CreateTLSConfiguration() instead
|
||||
- pattern: ^tls\.Config\.(InsecureSkipVerify|MinVersion|MaxVersion|CipherSuites|CurvePreferences)$
|
||||
msg: Do not set this field directly, use crypto.CreateTLSConfiguration() instead
|
||||
analyze-types: true
|
||||
exclusions:
|
||||
generated: lax
|
||||
presets:
|
||||
- comments
|
||||
- common-false-positives
|
||||
- legacy
|
||||
- std-error-handling
|
||||
paths:
|
||||
- third_party$
|
||||
- builtin$
|
||||
- examples$
|
||||
formatters:
|
||||
exclusions:
|
||||
generated: lax
|
||||
paths:
|
||||
- third_party$
|
||||
- builtin$
|
||||
- examples$
|
||||
|
||||
17
README.md
17
README.md
@@ -8,9 +8,9 @@ Portainer consists of a single container that can run on any cluster. It can be
|
||||
|
||||
**Portainer Business Edition** builds on the open-source base and includes a range of advanced features and functions (like RBAC and Support) that are specific to the needs of business users.
|
||||
|
||||
- [Compare Portainer CE and Compare Portainer BE](https://portainer.io/products)
|
||||
- [Compare Portainer CE and Compare Portainer BE](https://www.portainer.io/features)
|
||||
- [Take3 – get 3 free nodes of Portainer Business for as long as you want them](https://www.portainer.io/take-3)
|
||||
- [Portainer BE install guide](https://install.portainer.io)
|
||||
- [Portainer BE install guide](https://academy.portainer.io/install/)
|
||||
|
||||
## Latest Version
|
||||
|
||||
@@ -20,22 +20,19 @@ Portainer CE is updated regularly. We aim to do an update release every couple o
|
||||
|
||||
## Getting started
|
||||
|
||||
- [Deploy Portainer](https://docs.portainer.io/start/install)
|
||||
- [Deploy Portainer](https://docs.portainer.io/start/install-ce)
|
||||
- [Documentation](https://docs.portainer.io)
|
||||
- [Contribute to the project](https://docs.portainer.io/contribute/contribute)
|
||||
|
||||
## Features & Functions
|
||||
|
||||
View [this](https://www.portainer.io/products) table to see all of the Portainer CE functionality and compare to Portainer Business.
|
||||
|
||||
- [Portainer CE for Docker / Docker Swarm](https://www.portainer.io/solutions/docker)
|
||||
- [Portainer CE for Kubernetes](https://www.portainer.io/solutions/kubernetes-ui)
|
||||
View [this](https://www.portainer.io/features) table to see all of the Portainer CE functionality and compare to Portainer Business.
|
||||
|
||||
## Getting help
|
||||
|
||||
Portainer CE is an open source project and is supported by the community. You can buy a supported version of Portainer at portainer.io
|
||||
|
||||
Learn more about Portainer's community support channels [here.](https://www.portainer.io/get-support-for-portainer)
|
||||
Learn more about Portainer's community support channels [here.](https://www.portainer.io/resources/get-help/get-support)
|
||||
|
||||
- Issues: https://github.com/portainer/portainer/issues
|
||||
- Slack (chat): [https://portainer.io/slack](https://portainer.io/slack)
|
||||
@@ -53,13 +50,13 @@ You can join the Portainer Community by visiting [https://www.portainer.io/join-
|
||||
|
||||
## Work for us
|
||||
|
||||
If you are a developer, and our code in this repo makes sense to you, we would love to hear from you. We are always on the hunt for awesome devs, either freelance or employed. Drop us a line to info@portainer.io with your details and/or visit our [careers page](https://portainer.io/careers).
|
||||
If you are a developer, and our code in this repo makes sense to you, we would love to hear from you. We are always on the hunt for awesome devs, either freelance or employed. Drop us a line to success@portainer.io with your details and/or visit our [careers page](https://apply.workable.com/portainer/).
|
||||
|
||||
## Privacy
|
||||
|
||||
**To make sure we focus our development effort in the right places we need to know which features get used most often. To give us this information we use [Matomo Analytics](https://matomo.org/), which is hosted in Germany and is fully GDPR compliant.**
|
||||
|
||||
When Portainer first starts, you are given the option to DISABLE analytics. If you **don't** choose to disable it, we collect anonymous usage as per [our privacy policy](https://www.portainer.io/privacy-policy). **Please note**, there is no personally identifiable information sent or stored at any time and we only use the data to help us improve Portainer.
|
||||
When Portainer first starts, you are given the option to DISABLE analytics. If you **don't** choose to disable it, we collect anonymous usage as per [our privacy policy](https://www.portainer.io/legal/privacy-policy). **Please note**, there is no personally identifiable information sent or stored at any time and we only use the data to help us improve Portainer.
|
||||
|
||||
## Limitations
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
// GetAgentVersionAndPlatform returns the agent version and platform
|
||||
//
|
||||
// it sends a ping to the agent and parses the version and platform from the headers
|
||||
func GetAgentVersionAndPlatform(endpointUrl string, tlsConfig *tls.Config) (portainer.AgentPlatform, string, error) {
|
||||
func GetAgentVersionAndPlatform(endpointUrl string, tlsConfig *tls.Config) (portainer.AgentPlatform, string, error) { //nolint:forbidigo
|
||||
httpCli := &http.Client{
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package archive
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
@@ -12,50 +11,6 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// UnzipArchive will unzip an archive from bytes into the dest destination folder on disk
|
||||
func UnzipArchive(archiveData []byte, dest string) error {
|
||||
zipReader, err := zip.NewReader(bytes.NewReader(archiveData), int64(len(archiveData)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, zipFile := range zipReader.File {
|
||||
err := extractFileFromArchive(zipFile, dest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func extractFileFromArchive(file *zip.File, dest string) error {
|
||||
f, err := file.Open()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
data, err := io.ReadAll(f)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fpath := filepath.Join(dest, file.Name)
|
||||
|
||||
outFile, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, file.Mode())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = io.Copy(outFile, bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return outFile.Close()
|
||||
}
|
||||
|
||||
// UnzipFile will decompress a zip archive, moving all files and folders
|
||||
// within the zip file (parameter 1) to an output directory (parameter 2).
|
||||
func UnzipFile(src string, dest string) error {
|
||||
@@ -76,11 +31,11 @@ func UnzipFile(src string, dest string) error {
|
||||
if f.FileInfo().IsDir() {
|
||||
// Make Folder
|
||||
os.MkdirAll(p, os.ModePerm)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
err = unzipFile(f, p)
|
||||
if err != nil {
|
||||
if err := unzipFile(f, p); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -93,20 +48,20 @@ func unzipFile(f *zip.File, p string) error {
|
||||
if err := os.MkdirAll(filepath.Dir(p), os.ModePerm); err != nil {
|
||||
return errors.Wrapf(err, "unzipFile: can't make a path %s", p)
|
||||
}
|
||||
|
||||
outFile, err := os.OpenFile(p, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "unzipFile: can't create file %s", p)
|
||||
}
|
||||
defer outFile.Close()
|
||||
|
||||
rc, err := f.Open()
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "unzipFile: can't open zip file %s in the archive", f.Name)
|
||||
}
|
||||
defer rc.Close()
|
||||
|
||||
_, err = io.Copy(outFile, rc)
|
||||
|
||||
if err != nil {
|
||||
if _, err = io.Copy(outFile, rc); err != nil {
|
||||
return errors.Wrapf(err, "unzipFile: can't copy an archived file content")
|
||||
}
|
||||
|
||||
|
||||
@@ -54,8 +54,8 @@ func ecdsaGenerateKey(c elliptic.Curve, rand io.Reader) (*ecdsa.PrivateKey, erro
|
||||
}
|
||||
|
||||
priv := new(ecdsa.PrivateKey)
|
||||
priv.PublicKey.Curve = c
|
||||
priv.Curve = c
|
||||
priv.D = k
|
||||
priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes())
|
||||
priv.X, priv.Y = c.ScalarBaseMult(k.Bytes())
|
||||
return priv, nil
|
||||
}
|
||||
|
||||
@@ -9,10 +9,15 @@ import (
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func init() {
|
||||
fips.InitFIPS(false)
|
||||
}
|
||||
|
||||
func TestPingAgentPanic(t *testing.T) {
|
||||
endpoint := &portainer.Endpoint{
|
||||
ID: 1,
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -14,6 +13,7 @@ import (
|
||||
"github.com/portainer/portainer/api/internal/edge/cache"
|
||||
"github.com/portainer/portainer/api/internal/endpointutils"
|
||||
"github.com/portainer/portainer/pkg/libcrypto"
|
||||
"github.com/portainer/portainer/pkg/librand"
|
||||
|
||||
"github.com/dchest/uniuri"
|
||||
"github.com/rs/zerolog/log"
|
||||
@@ -200,7 +200,9 @@ func (service *Service) getUnusedPort() int {
|
||||
|
||||
conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: port})
|
||||
if err == nil {
|
||||
conn.Close()
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Warn().Msg("failed to close tcp connection that checks if port is free")
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Int("port", port).
|
||||
@@ -213,7 +215,7 @@ func (service *Service) getUnusedPort() int {
|
||||
}
|
||||
|
||||
func randomInt(min, max int) int {
|
||||
return min + rand.Intn(max-min)
|
||||
return min + librand.Intn(max-min)
|
||||
}
|
||||
|
||||
func generateRandomCredentials() (string, string) {
|
||||
|
||||
79
api/chisel/tunnel_test.go
Normal file
79
api/chisel/tunnel_test.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package chisel
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
)
|
||||
|
||||
type testSettingsService struct {
|
||||
dataservices.SettingsService
|
||||
}
|
||||
|
||||
func (s *testSettingsService) Settings() (*portainer.Settings, error) {
|
||||
return &portainer.Settings{
|
||||
EdgeAgentCheckinInterval: 1,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type testStore struct {
|
||||
dataservices.DataStore
|
||||
}
|
||||
|
||||
func (s *testStore) Settings() dataservices.SettingsService {
|
||||
return &testSettingsService{}
|
||||
}
|
||||
|
||||
func TestGetUnusedPort(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
existingTunnels map[portainer.EndpointID]*portainer.TunnelDetails
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "simple case",
|
||||
},
|
||||
{
|
||||
name: "existing tunnels",
|
||||
existingTunnels: map[portainer.EndpointID]*portainer.TunnelDetails{
|
||||
portainer.EndpointID(1): {
|
||||
Port: 53072,
|
||||
},
|
||||
portainer.EndpointID(2): {
|
||||
Port: 63072,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
store := &testStore{}
|
||||
s := NewService(store, nil, nil)
|
||||
s.activeTunnels = tc.existingTunnels
|
||||
port := s.getUnusedPort()
|
||||
|
||||
if port < 49152 || port > 65535 {
|
||||
t.Fatalf("Expected port to be inbetween 49152 and 65535 but got %d", port)
|
||||
}
|
||||
|
||||
for _, tun := range tc.existingTunnels {
|
||||
if tun.Port == port {
|
||||
t.Fatalf("returned port %d already has an existing tunnel", port)
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: port})
|
||||
if err == nil {
|
||||
// Ignore error
|
||||
_ = conn.Close()
|
||||
t.Fatalf("expected port %d to be unused", port)
|
||||
} else if !strings.Contains(err.Error(), "connection refused") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
|
||||
"github.com/alecthomas/kingpin/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gopkg.in/alecthomas/kingpin.v2"
|
||||
)
|
||||
|
||||
// Service implements the CLIService interface
|
||||
@@ -35,16 +35,9 @@ func CLIFlags() *portainer.CLIFlags {
|
||||
FeatureFlags: kingpin.Flag("feat", "List of feature flags").Strings(),
|
||||
EnableEdgeComputeFeatures: kingpin.Flag("edge-compute", "Enable Edge Compute features").Bool(),
|
||||
NoAnalytics: kingpin.Flag("no-analytics", "Disable Analytics in app (deprecated)").Bool(),
|
||||
TLS: kingpin.Flag("tlsverify", "TLS support").Default(defaultTLS).Bool(),
|
||||
TLSSkipVerify: kingpin.Flag("tlsskipverify", "Disable TLS server verification").Default(defaultTLSSkipVerify).Bool(),
|
||||
TLSCacert: kingpin.Flag("tlscacert", "Path to the CA").Default(defaultTLSCACertPath).String(),
|
||||
TLSCert: kingpin.Flag("tlscert", "Path to the TLS certificate file").Default(defaultTLSCertPath).String(),
|
||||
TLSKey: kingpin.Flag("tlskey", "Path to the TLS key").Default(defaultTLSKeyPath).String(),
|
||||
HTTPDisabled: kingpin.Flag("http-disabled", "Serve portainer only on https").Default(defaultHTTPDisabled).Bool(),
|
||||
HTTPEnabled: kingpin.Flag("http-enabled", "Serve portainer on http").Default(defaultHTTPEnabled).Bool(),
|
||||
SSL: kingpin.Flag("ssl", "Secure Portainer instance using SSL (deprecated)").Default(defaultSSL).Bool(),
|
||||
SSLCert: kingpin.Flag("sslcert", "Path to the SSL certificate used to secure the Portainer instance").String(),
|
||||
SSLKey: kingpin.Flag("sslkey", "Path to the SSL key used to secure the Portainer instance").String(),
|
||||
Rollback: kingpin.Flag("rollback", "Rollback the database to the previous backup").Bool(),
|
||||
SnapshotInterval: kingpin.Flag("snapshot-interval", "Duration between each environment snapshot job").String(),
|
||||
AdminPassword: kingpin.Flag("admin-password", "Set admin password with provided hash").String(),
|
||||
@@ -60,15 +53,48 @@ func CLIFlags() *portainer.CLIFlags {
|
||||
LogLevel: kingpin.Flag("log-level", "Set the minimum logging level to show").Default("INFO").Enum("DEBUG", "INFO", "WARN", "ERROR"),
|
||||
LogMode: kingpin.Flag("log-mode", "Set the logging output mode").Default("PRETTY").Enum("NOCOLOR", "PRETTY", "JSON"),
|
||||
KubectlShellImage: kingpin.Flag("kubectl-shell-image", "Kubectl shell image").Envar(portainer.KubectlShellImageEnvVar).Default(portainer.DefaultKubectlShellImage).String(),
|
||||
PullLimitCheckDisabled: kingpin.Flag("pull-limit-check-disabled", "Pull limit check").Envar(portainer.PullLimitCheckDisabledEnvVar).Default(defaultPullLimitCheckDisabled).Bool(),
|
||||
TrustedOrigins: kingpin.Flag("trusted-origins", "List of trusted origins for CSRF protection. Separate multiple origins with a comma.").Envar(portainer.TrustedOriginsEnvVar).String(),
|
||||
CSP: kingpin.Flag("csp", "Content Security Policy (CSP) header").Envar(portainer.CSPEnvVar).Default("true").Bool(),
|
||||
CompactDB: kingpin.Flag("compact-db", "Enable database compaction on startup").Envar(portainer.CompactDBEnvVar).Default("false").Bool(),
|
||||
}
|
||||
}
|
||||
|
||||
// ParseFlags parse the CLI flags and return a portainer.Flags struct
|
||||
func (*Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
|
||||
func (Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
|
||||
kingpin.Version(version)
|
||||
|
||||
var hasSSLFlag, hasSSLCertFlag, hasSSLKeyFlag bool
|
||||
sslFlag := kingpin.Flag(
|
||||
"ssl",
|
||||
"Secure Portainer instance using SSL (deprecated)",
|
||||
).Default(defaultSSL).IsSetByUser(&hasSSLFlag)
|
||||
ssl := sslFlag.Bool()
|
||||
sslCertFlag := kingpin.Flag(
|
||||
"sslcert",
|
||||
"Path to the SSL certificate used to secure the Portainer instance",
|
||||
).IsSetByUser(&hasSSLCertFlag)
|
||||
sslCert := sslCertFlag.String()
|
||||
sslKeyFlag := kingpin.Flag(
|
||||
"sslkey",
|
||||
"Path to the SSL key used to secure the Portainer instance",
|
||||
).IsSetByUser(&hasSSLKeyFlag)
|
||||
sslKey := sslKeyFlag.String()
|
||||
|
||||
flags := CLIFlags()
|
||||
|
||||
var hasTLSFlag, hasTLSCertFlag, hasTLSKeyFlag bool
|
||||
tlsFlag := kingpin.Flag("tlsverify", "TLS support").Default(defaultTLS).IsSetByUser(&hasTLSFlag)
|
||||
flags.TLS = tlsFlag.Bool()
|
||||
tlsCertFlag := kingpin.Flag(
|
||||
"tlscert",
|
||||
"Path to the TLS certificate file",
|
||||
).Default(defaultTLSCertPath).IsSetByUser(&hasTLSCertFlag)
|
||||
flags.TLSCert = tlsCertFlag.String()
|
||||
tlsKeyFlag := kingpin.Flag("tlskey", "Path to the TLS key").Default(defaultTLSKeyPath).IsSetByUser(&hasTLSKeyFlag)
|
||||
flags.TLSKey = tlsKeyFlag.String()
|
||||
flags.TLSCacert = kingpin.Flag("tlscacert", "Path to the CA").Default(defaultTLSCACertPath).String()
|
||||
|
||||
kingpin.Parse()
|
||||
|
||||
if !filepath.IsAbs(*flags.Assets) {
|
||||
@@ -80,11 +106,46 @@ func (*Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
|
||||
*flags.Assets = filepath.Join(filepath.Dir(ex), *flags.Assets)
|
||||
}
|
||||
|
||||
// If the user didn't provide a tls flag remove the defaults to match previous behaviour
|
||||
if !hasTLSFlag {
|
||||
if !hasTLSCertFlag {
|
||||
*flags.TLSCert = ""
|
||||
}
|
||||
|
||||
if !hasTLSKeyFlag {
|
||||
*flags.TLSKey = ""
|
||||
}
|
||||
}
|
||||
|
||||
if hasSSLFlag {
|
||||
log.Warn().Msgf("the %q flag is deprecated. use %q instead.", sslFlag.Model().Name, tlsFlag.Model().Name)
|
||||
|
||||
if !hasTLSFlag {
|
||||
flags.TLS = ssl
|
||||
}
|
||||
}
|
||||
|
||||
if hasSSLCertFlag {
|
||||
log.Warn().Msgf("the %q flag is deprecated. use %q instead.", sslCertFlag.Model().Name, tlsCertFlag.Model().Name)
|
||||
|
||||
if !hasTLSCertFlag {
|
||||
flags.TLSCert = sslCert
|
||||
}
|
||||
}
|
||||
|
||||
if hasSSLKeyFlag {
|
||||
log.Warn().Msgf("the %q flag is deprecated. use %q instead.", sslKeyFlag.Model().Name, tlsKeyFlag.Model().Name)
|
||||
|
||||
if !hasTLSKeyFlag {
|
||||
flags.TLSKey = sslKey
|
||||
}
|
||||
}
|
||||
|
||||
return flags, nil
|
||||
}
|
||||
|
||||
// ValidateFlags validates the values of the flags.
|
||||
func (*Service) ValidateFlags(flags *portainer.CLIFlags) error {
|
||||
func (Service) ValidateFlags(flags *portainer.CLIFlags) error {
|
||||
displayDeprecationWarnings(flags)
|
||||
|
||||
if err := validateEndpointURL(*flags.EndpointURL); err != nil {
|
||||
@@ -106,10 +167,6 @@ func displayDeprecationWarnings(flags *portainer.CLIFlags) {
|
||||
if *flags.NoAnalytics {
|
||||
log.Warn().Msg("the --no-analytics flag has been kept to allow migration of instances running a previous version of Portainer with this flag enabled, to version 2.0 where enabling this flag will have no effect")
|
||||
}
|
||||
|
||||
if *flags.SSL {
|
||||
log.Warn().Msg("SSL is enabled by default and there is no need for the --ssl flag, it has been kept to allow migration of instances running a previous version of Portainer with this flag enabled")
|
||||
}
|
||||
}
|
||||
|
||||
func validateEndpointURL(endpointURL string) error {
|
||||
|
||||
209
api/cli/cli_test.go
Normal file
209
api/cli/cli_test.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
zerolog "github.com/rs/zerolog/log"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOptionParser(t *testing.T) {
|
||||
p := Service{}
|
||||
require.NotNil(t, p)
|
||||
|
||||
a := os.Args
|
||||
defer func() { os.Args = a }()
|
||||
|
||||
os.Args = []string{"portainer", "--edge-compute"}
|
||||
|
||||
opts, err := p.ParseFlags("2.34.5")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.False(t, *opts.HTTPDisabled)
|
||||
require.True(t, *opts.EnableEdgeComputeFeatures)
|
||||
}
|
||||
|
||||
func TestParseTLSFlags(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedTLSFlag bool
|
||||
expectedTLSCertFlag string
|
||||
expectedTLSKeyFlag string
|
||||
expectedLogMessages []string
|
||||
}{
|
||||
{
|
||||
name: "no flags",
|
||||
expectedTLSFlag: false,
|
||||
expectedTLSCertFlag: "",
|
||||
expectedTLSKeyFlag: "",
|
||||
},
|
||||
{
|
||||
name: "only ssl flag",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--ssl",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "",
|
||||
expectedTLSKeyFlag: "",
|
||||
},
|
||||
{
|
||||
name: "only tls flag",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--tlsverify",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: defaultTLSCertPath,
|
||||
expectedTLSKeyFlag: defaultTLSKeyPath,
|
||||
},
|
||||
{
|
||||
name: "partial ssl flags",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--ssl",
|
||||
"--sslcert=ssl-cert-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "ssl-cert-flag-value",
|
||||
expectedTLSKeyFlag: "",
|
||||
},
|
||||
{
|
||||
name: "partial tls flags",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--tlsverify",
|
||||
"--tlscert=tls-cert-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "tls-cert-flag-value",
|
||||
expectedTLSKeyFlag: defaultTLSKeyPath,
|
||||
},
|
||||
{
|
||||
name: "partial tls and ssl flags",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--tlsverify",
|
||||
"--tlscert=tls-cert-flag-value",
|
||||
"--sslkey=ssl-key-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "tls-cert-flag-value",
|
||||
expectedTLSKeyFlag: "ssl-key-flag-value",
|
||||
},
|
||||
{
|
||||
name: "partial tls and ssl flags 2",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--ssl",
|
||||
"--tlscert=tls-cert-flag-value",
|
||||
"--sslkey=ssl-key-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "tls-cert-flag-value",
|
||||
expectedTLSKeyFlag: "ssl-key-flag-value",
|
||||
},
|
||||
{
|
||||
name: "ssl flags",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--ssl",
|
||||
"--sslcert=ssl-cert-flag-value",
|
||||
"--sslkey=ssl-key-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "ssl-cert-flag-value",
|
||||
expectedTLSKeyFlag: "ssl-key-flag-value",
|
||||
expectedLogMessages: []string{
|
||||
"the \\\"ssl\\\" flag is deprecated. use \\\"tlsverify\\\" instead.",
|
||||
"the \\\"sslcert\\\" flag is deprecated. use \\\"tlscert\\\" instead.",
|
||||
"the \\\"sslkey\\\" flag is deprecated. use \\\"tlskey\\\" instead.",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tls flags",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--tlsverify",
|
||||
"--tlscert=tls-cert-flag-value",
|
||||
"--tlskey=tls-key-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "tls-cert-flag-value",
|
||||
expectedTLSKeyFlag: "tls-key-flag-value",
|
||||
},
|
||||
{
|
||||
name: "tls and ssl flags",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--tlsverify",
|
||||
"--tlscert=tls-cert-flag-value",
|
||||
"--tlskey=tls-key-flag-value",
|
||||
"--ssl",
|
||||
"--sslcert=ssl-cert-flag-value",
|
||||
"--sslkey=ssl-key-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "tls-cert-flag-value",
|
||||
expectedTLSKeyFlag: "tls-key-flag-value",
|
||||
expectedLogMessages: []string{
|
||||
"the \\\"ssl\\\" flag is deprecated. use \\\"tlsverify\\\" instead.",
|
||||
"the \\\"sslcert\\\" flag is deprecated. use \\\"tlscert\\\" instead.",
|
||||
"the \\\"sslkey\\\" flag is deprecated. use \\\"tlskey\\\" instead.",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var logOutput strings.Builder
|
||||
setupLogOutput(t, &logOutput)
|
||||
|
||||
if tc.args == nil {
|
||||
tc.args = []string{"portainer"}
|
||||
}
|
||||
setOsArgs(t, tc.args)
|
||||
|
||||
s := Service{}
|
||||
flags, err := s.ParseFlags("test-version")
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing flags: %v", err)
|
||||
}
|
||||
|
||||
if flags.TLS == nil {
|
||||
t.Fatal("TLS flag was nil")
|
||||
}
|
||||
|
||||
require.Equal(t, tc.expectedTLSFlag, *flags.TLS, "tlsverify flag didn't match")
|
||||
require.Equal(t, tc.expectedTLSCertFlag, *flags.TLSCert, "tlscert flag didn't match")
|
||||
require.Equal(t, tc.expectedTLSKeyFlag, *flags.TLSKey, "tlskey flag didn't match")
|
||||
|
||||
for _, expectedLogMessage := range tc.expectedLogMessages {
|
||||
require.Contains(t, logOutput.String(), expectedLogMessage, "Log didn't contain expected message")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func setOsArgs(t *testing.T, args []string) {
|
||||
t.Helper()
|
||||
previousArgs := os.Args
|
||||
os.Args = args
|
||||
t.Cleanup(func() {
|
||||
os.Args = previousArgs
|
||||
})
|
||||
}
|
||||
|
||||
func setupLogOutput(t *testing.T, w io.Writer) {
|
||||
t.Helper()
|
||||
|
||||
oldLogger := zerolog.Logger
|
||||
zerolog.Logger = zerolog.Output(w)
|
||||
t.Cleanup(func() {
|
||||
zerolog.Logger = oldLogger
|
||||
})
|
||||
}
|
||||
@@ -4,20 +4,21 @@
|
||||
package cli
|
||||
|
||||
const (
|
||||
defaultBindAddress = ":9000"
|
||||
defaultHTTPSBindAddress = ":9443"
|
||||
defaultTunnelServerAddress = "0.0.0.0"
|
||||
defaultTunnelServerPort = "8000"
|
||||
defaultDataDirectory = "/data"
|
||||
defaultAssetsDirectory = "./"
|
||||
defaultTLS = "false"
|
||||
defaultTLSSkipVerify = "false"
|
||||
defaultTLSCACertPath = "/certs/ca.pem"
|
||||
defaultTLSCertPath = "/certs/cert.pem"
|
||||
defaultTLSKeyPath = "/certs/key.pem"
|
||||
defaultHTTPDisabled = "false"
|
||||
defaultHTTPEnabled = "false"
|
||||
defaultSSL = "false"
|
||||
defaultBaseURL = "/"
|
||||
defaultSecretKeyName = "portainer"
|
||||
defaultBindAddress = ":9000"
|
||||
defaultHTTPSBindAddress = ":9443"
|
||||
defaultTunnelServerAddress = "0.0.0.0"
|
||||
defaultTunnelServerPort = "8000"
|
||||
defaultDataDirectory = "/data"
|
||||
defaultAssetsDirectory = "./"
|
||||
defaultTLS = "false"
|
||||
defaultTLSSkipVerify = "false"
|
||||
defaultTLSCACertPath = "/certs/ca.pem"
|
||||
defaultTLSCertPath = "/certs/cert.pem"
|
||||
defaultTLSKeyPath = "/certs/key.pem"
|
||||
defaultHTTPDisabled = "false"
|
||||
defaultHTTPEnabled = "false"
|
||||
defaultSSL = "false"
|
||||
defaultBaseURL = "/"
|
||||
defaultSecretKeyName = "portainer"
|
||||
defaultPullLimitCheckDisabled = "false"
|
||||
)
|
||||
|
||||
@@ -1,21 +1,22 @@
|
||||
package cli
|
||||
|
||||
const (
|
||||
defaultBindAddress = ":9000"
|
||||
defaultHTTPSBindAddress = ":9443"
|
||||
defaultTunnelServerAddress = "0.0.0.0"
|
||||
defaultTunnelServerPort = "8000"
|
||||
defaultDataDirectory = "C:\\data"
|
||||
defaultAssetsDirectory = "./"
|
||||
defaultTLS = "false"
|
||||
defaultTLSSkipVerify = "false"
|
||||
defaultTLSCACertPath = "C:\\certs\\ca.pem"
|
||||
defaultTLSCertPath = "C:\\certs\\cert.pem"
|
||||
defaultTLSKeyPath = "C:\\certs\\key.pem"
|
||||
defaultHTTPDisabled = "false"
|
||||
defaultHTTPEnabled = "false"
|
||||
defaultSSL = "false"
|
||||
defaultSnapshotInterval = "5m"
|
||||
defaultBaseURL = "/"
|
||||
defaultSecretKeyName = "portainer"
|
||||
defaultBindAddress = ":9000"
|
||||
defaultHTTPSBindAddress = ":9443"
|
||||
defaultTunnelServerAddress = "0.0.0.0"
|
||||
defaultTunnelServerPort = "8000"
|
||||
defaultDataDirectory = "C:\\data"
|
||||
defaultAssetsDirectory = "./"
|
||||
defaultTLS = "false"
|
||||
defaultTLSSkipVerify = "false"
|
||||
defaultTLSCACertPath = "C:\\certs\\ca.pem"
|
||||
defaultTLSCertPath = "C:\\certs\\cert.pem"
|
||||
defaultTLSKeyPath = "C:\\certs\\key.pem"
|
||||
defaultHTTPDisabled = "false"
|
||||
defaultHTTPEnabled = "false"
|
||||
defaultSSL = "false"
|
||||
defaultSnapshotInterval = "5m"
|
||||
defaultBaseURL = "/"
|
||||
defaultSecretKeyName = "portainer"
|
||||
defaultPullLimitCheckDisabled = "false"
|
||||
)
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/alecthomas/kingpin.v2"
|
||||
"github.com/alecthomas/kingpin/v2"
|
||||
)
|
||||
|
||||
type pairList []portainer.Pair
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
|
||||
"gopkg.in/alecthomas/kingpin.v2"
|
||||
)
|
||||
|
||||
type pairListBool []portainer.Pair
|
||||
|
||||
// Set implementation for a list of portainer.Pair
|
||||
func (l *pairListBool) Set(value string) error {
|
||||
p := new(portainer.Pair)
|
||||
|
||||
// default to true. example setting=true is equivalent to setting
|
||||
parts := strings.SplitN(value, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
p.Name = parts[0]
|
||||
p.Value = "true"
|
||||
} else {
|
||||
p.Name = parts[0]
|
||||
p.Value = parts[1]
|
||||
}
|
||||
|
||||
*l = append(*l, *p)
|
||||
return nil
|
||||
}
|
||||
|
||||
// String implementation for a list of pair
|
||||
func (l *pairListBool) String() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsCumulative implementation for a list of pair
|
||||
func (l *pairListBool) IsCumulative() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func BoolPairs(s kingpin.Settings) (target *[]portainer.Pair) {
|
||||
target = new([]portainer.Pair)
|
||||
s.SetValue((*pairListBool)(target))
|
||||
return
|
||||
}
|
||||
@@ -39,6 +39,7 @@ import (
|
||||
"github.com/portainer/portainer/api/kubernetes"
|
||||
kubecli "github.com/portainer/portainer/api/kubernetes/cli"
|
||||
"github.com/portainer/portainer/api/ldap"
|
||||
"github.com/portainer/portainer/api/logs"
|
||||
"github.com/portainer/portainer/api/oauth"
|
||||
"github.com/portainer/portainer/api/pendingactions"
|
||||
"github.com/portainer/portainer/api/pendingactions/actions"
|
||||
@@ -48,15 +49,18 @@ import (
|
||||
"github.com/portainer/portainer/api/stacks/deployments"
|
||||
"github.com/portainer/portainer/pkg/build"
|
||||
"github.com/portainer/portainer/pkg/featureflags"
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
"github.com/portainer/portainer/pkg/libhelm"
|
||||
libhelmtypes "github.com/portainer/portainer/pkg/libhelm/types"
|
||||
"github.com/portainer/portainer/pkg/libstack/compose"
|
||||
"github.com/portainer/portainer/pkg/validate"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func initCLI() *portainer.CLIFlags {
|
||||
cliService := &cli.Service{}
|
||||
cliService := cli.Service{}
|
||||
|
||||
flags, err := cliService.ParseFlags(portainer.APIVersion)
|
||||
if err != nil {
|
||||
@@ -80,7 +84,7 @@ func initFileService(dataStorePath string) portainer.FileService {
|
||||
}
|
||||
|
||||
func initDataStore(flags *portainer.CLIFlags, secretKey []byte, fileService portainer.FileService, shutdownCtx context.Context) dataservices.DataStore {
|
||||
connection, err := database.NewDatabase("boltdb", *flags.Data, secretKey)
|
||||
connection, err := database.NewDatabase("boltdb", *flags.Data, secretKey, *flags.CompactDB)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("failed creating database connection")
|
||||
}
|
||||
@@ -165,12 +169,12 @@ func checkDBSchemaServerVersionMatch(dbStore dataservices.DataStore, serverVersi
|
||||
return v.SchemaVersion == serverVersion && v.Edition == serverEdition
|
||||
}
|
||||
|
||||
func initKubernetesDeployer(kubernetesTokenCacheManager *kubeproxy.TokenCacheManager, kubernetesClientFactory *kubecli.ClientFactory, dataStore dataservices.DataStore, reverseTunnelService portainer.ReverseTunnelService, signatureService portainer.DigitalSignatureService, proxyManager *proxy.Manager, assetsPath string) portainer.KubernetesDeployer {
|
||||
return exec.NewKubernetesDeployer(kubernetesTokenCacheManager, kubernetesClientFactory, dataStore, reverseTunnelService, signatureService, proxyManager, assetsPath)
|
||||
func initKubernetesDeployer(kubernetesTokenCacheManager *kubeproxy.TokenCacheManager, kubernetesClientFactory *kubecli.ClientFactory, dataStore dataservices.DataStore, reverseTunnelService portainer.ReverseTunnelService, signatureService portainer.DigitalSignatureService, proxyManager *proxy.Manager) portainer.KubernetesDeployer {
|
||||
return exec.NewKubernetesDeployer(kubernetesTokenCacheManager, kubernetesClientFactory, dataStore, reverseTunnelService, signatureService, proxyManager)
|
||||
}
|
||||
|
||||
func initHelmPackageManager(assetsPath string) (libhelm.HelmPackageManager, error) {
|
||||
return libhelm.NewHelmPackageManager(libhelm.HelmConfig{BinaryPath: assetsPath})
|
||||
func initHelmPackageManager() (libhelmtypes.HelmPackageManager, error) {
|
||||
return libhelm.NewHelmPackageManager()
|
||||
}
|
||||
|
||||
func initAPIKeyService(datastore dataservices.DataStore) apikey.APIKeyService {
|
||||
@@ -303,8 +307,19 @@ func initKeyPair(fileService portainer.FileService, signatureService portainer.D
|
||||
return generateAndStoreKeyPair(fileService, signatureService)
|
||||
}
|
||||
|
||||
// dbSecretPath build the path to the file that contains the db encryption
|
||||
// secret. Normally in Docker this is built from the static path inside
|
||||
// /run/secrets for example: /run/secrets/<keyFilenameFlag> but for ease of
|
||||
// use outside Docker it also accepts an absolute path
|
||||
func dbSecretPath(keyFilenameFlag string) string {
|
||||
if path.IsAbs(keyFilenameFlag) {
|
||||
return keyFilenameFlag
|
||||
}
|
||||
return path.Join("/run/secrets", keyFilenameFlag)
|
||||
}
|
||||
|
||||
func loadEncryptionSecretKey(keyfilename string) []byte {
|
||||
content, err := os.ReadFile(path.Join("/run/secrets", keyfilename))
|
||||
content, err := os.ReadFile(keyfilename)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
log.Info().Str("filename", keyfilename).Msg("encryption key file not present")
|
||||
@@ -316,6 +331,7 @@ func loadEncryptionSecretKey(keyfilename string) []byte {
|
||||
}
|
||||
|
||||
// return a 32 byte hash of the secret (required for AES)
|
||||
// fips compliant version of this is not implemented in -ce
|
||||
hash := sha256.Sum256(content)
|
||||
|
||||
return hash[:]
|
||||
@@ -328,8 +344,23 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
|
||||
featureflags.Parse(*flags.FeatureFlags, portainer.SupportedFeatureFlags)
|
||||
}
|
||||
|
||||
trustedOrigins := []string{}
|
||||
if *flags.TrustedOrigins != "" {
|
||||
// validate if the trusted origins are valid urls
|
||||
for _, origin := range strings.Split(*flags.TrustedOrigins, ",") {
|
||||
if !validate.IsTrustedOrigin(origin) {
|
||||
log.Fatal().Str("trusted_origin", origin).Msg("invalid url for trusted origin. Please check the trusted origins flag.")
|
||||
}
|
||||
|
||||
trustedOrigins = append(trustedOrigins, origin)
|
||||
}
|
||||
}
|
||||
|
||||
// -ce can not ever be run in FIPS mode
|
||||
fips.InitFIPS(false)
|
||||
|
||||
fileService := initFileService(*flags.Data)
|
||||
encryptionKey := loadEncryptionSecretKey(*flags.SecretKeyName)
|
||||
encryptionKey := loadEncryptionSecretKey(dbSecretPath(*flags.SecretKeyName))
|
||||
if encryptionKey == nil {
|
||||
log.Info().Msg("proceeding without encryption key")
|
||||
}
|
||||
@@ -362,21 +393,22 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
|
||||
log.Fatal().Err(err).Msg("failed initializing JWT service")
|
||||
}
|
||||
|
||||
ldapService := &ldap.Service{}
|
||||
ldapService := ldap.Service{}
|
||||
|
||||
oauthService := oauth.NewService()
|
||||
|
||||
gitService := git.NewService(shutdownCtx)
|
||||
|
||||
openAMTService := openamt.NewService()
|
||||
// Setting insecureSkipVerify to true to preserve the old behaviour.
|
||||
openAMTService := openamt.NewService(true)
|
||||
|
||||
cryptoService := &crypto.Service{}
|
||||
cryptoService := crypto.Service{}
|
||||
|
||||
signatureService := initDigitalSignatureService()
|
||||
|
||||
edgeStacksService := edgestacks.NewService(dataStore)
|
||||
|
||||
sslService, err := initSSLService(*flags.AddrHTTPS, *flags.SSLCert, *flags.SSLKey, fileService, dataStore, shutdownTrigger)
|
||||
sslService, err := initSSLService(*flags.AddrHTTPS, *flags.TLSCert, *flags.TLSKey, fileService, dataStore, shutdownTrigger)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("")
|
||||
}
|
||||
@@ -421,7 +453,7 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
|
||||
log.Fatal().Err(err).Msg("failed initializing swarm stack manager")
|
||||
}
|
||||
|
||||
kubernetesDeployer := initKubernetesDeployer(kubernetesTokenCacheManager, kubernetesClientFactory, dataStore, reverseTunnelService, signatureService, proxyManager, *flags.Assets)
|
||||
kubernetesDeployer := initKubernetesDeployer(kubernetesTokenCacheManager, kubernetesClientFactory, dataStore, reverseTunnelService, signatureService, proxyManager)
|
||||
|
||||
pendingActionsService := pendingactions.NewService(dataStore, kubernetesClientFactory)
|
||||
pendingActionsService.RegisterHandler(actions.CleanNAPWithOverridePolicies, handlers.NewHandlerCleanNAPWithOverridePolicies(authorizationService, dataStore))
|
||||
@@ -435,9 +467,9 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
|
||||
|
||||
snapshotService.Start()
|
||||
|
||||
proxyManager.NewProxyFactory(dataStore, signatureService, reverseTunnelService, dockerClientFactory, kubernetesClientFactory, kubernetesTokenCacheManager, gitService, snapshotService)
|
||||
proxyManager.NewProxyFactory(dataStore, signatureService, reverseTunnelService, dockerClientFactory, kubernetesClientFactory, kubernetesTokenCacheManager, gitService, snapshotService, jwtService)
|
||||
|
||||
helmPackageManager, err := initHelmPackageManager(*flags.Assets)
|
||||
helmPackageManager, err := initHelmPackageManager()
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("failed initializing helm package manager")
|
||||
}
|
||||
@@ -543,6 +575,7 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
|
||||
Status: applicationStatus,
|
||||
BindAddress: *flags.Addr,
|
||||
BindAddressHTTPS: *flags.AddrHTTPS,
|
||||
CSP: *flags.CSP,
|
||||
HTTPEnabled: sslDBSettings.HTTPEnabled,
|
||||
AssetsPath: *flags.Assets,
|
||||
DataStore: dataStore,
|
||||
@@ -575,17 +608,19 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
|
||||
AdminCreationDone: adminCreationDone,
|
||||
PendingActionsService: pendingActionsService,
|
||||
PlatformService: platformService,
|
||||
PullLimitCheckDisabled: *flags.PullLimitCheckDisabled,
|
||||
TrustedOrigins: trustedOrigins,
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
configureLogger()
|
||||
setLoggingMode("PRETTY")
|
||||
logs.ConfigureLogger()
|
||||
logs.SetLoggingMode("PRETTY")
|
||||
|
||||
flags := initCLI()
|
||||
|
||||
setLoggingLevel(*flags.LogLevel)
|
||||
setLoggingMode(*flags.LogMode)
|
||||
logs.SetLoggingLevel(*flags.LogLevel)
|
||||
logs.SetLoggingMode(*flags.LogMode)
|
||||
|
||||
for {
|
||||
server := buildServer(flags)
|
||||
|
||||
57
api/cmd/portainer/main_test.go
Normal file
57
api/cmd/portainer/main_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const secretFileName = "secret.txt"
|
||||
|
||||
func createPasswordFile(t *testing.T, secretPath, password string) string {
|
||||
err := os.WriteFile(secretPath, []byte(password), 0600)
|
||||
require.NoError(t, err)
|
||||
return secretPath
|
||||
}
|
||||
|
||||
func TestLoadEncryptionSecretKey(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
secretPath := path.Join(tempDir, secretFileName)
|
||||
|
||||
// first pointing to file that does not exist, gives nil hash (no encryption)
|
||||
encryptionKey := loadEncryptionSecretKey(secretPath)
|
||||
require.Nil(t, encryptionKey)
|
||||
|
||||
// point to a directory instead of a file
|
||||
encryptionKey = loadEncryptionSecretKey(tempDir)
|
||||
require.Nil(t, encryptionKey)
|
||||
|
||||
password := "portainer@1234"
|
||||
createPasswordFile(t, secretPath, password)
|
||||
|
||||
encryptionKey = loadEncryptionSecretKey(secretPath)
|
||||
require.NotNil(t, encryptionKey)
|
||||
// should be 32 bytes for aes256 encryption
|
||||
require.Len(t, encryptionKey, 32)
|
||||
}
|
||||
|
||||
func TestDBSecretPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
keyFilenameFlag string
|
||||
expected string
|
||||
}{
|
||||
{keyFilenameFlag: "secret.txt", expected: "/run/secrets/secret.txt"},
|
||||
{keyFilenameFlag: "/tmp/secret.txt", expected: "/tmp/secret.txt"},
|
||||
{keyFilenameFlag: "/run/secrets/secret.txt", expected: "/run/secrets/secret.txt"},
|
||||
{keyFilenameFlag: "./secret.txt", expected: "/run/secrets/secret.txt"},
|
||||
{keyFilenameFlag: "../secret.txt", expected: "/run/secret.txt"},
|
||||
{keyFilenameFlag: "foo/bar/secret.txt", expected: "/run/secrets/foo/bar/secret.txt"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
assert.Equal(t, test.expected, dbSecretPath(test.keyFilenameFlag))
|
||||
}
|
||||
}
|
||||
@@ -6,8 +6,10 @@ import (
|
||||
|
||||
type ReadTransaction interface {
|
||||
GetObject(bucketName string, key []byte, object any) error
|
||||
GetRawBytes(bucketName string, key []byte) ([]byte, error)
|
||||
GetAll(bucketName string, obj any, append func(o any) (any, error)) error
|
||||
GetAllWithKeyPrefix(bucketName string, keyPrefix []byte, obj any, append func(o any) (any, error)) error
|
||||
KeyExists(bucketName string, key []byte) (bool, error)
|
||||
}
|
||||
|
||||
type Transaction interface {
|
||||
|
||||
@@ -5,10 +5,15 @@ import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/pbkdf2"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
"golang.org/x/crypto/scrypt"
|
||||
@@ -19,20 +24,32 @@ const (
|
||||
aesGcmHeader = "AES256-GCM" // The encrypted file header
|
||||
aesGcmBlockSize = 1024 * 1024 // 1MB block for aes gcm
|
||||
|
||||
aesGcmFIPSHeader = "FIPS-AES256-GCM"
|
||||
aesGcmFIPSBlockSize = 16 * 1024 * 1024 // 16MB block for aes gcm
|
||||
|
||||
// Argon2 settings
|
||||
// Recommded settings lower memory hardware according to current OWASP recommendations
|
||||
// Recommended settings lower memory hardware according to current OWASP recommendations
|
||||
// Considering some people run portainer on a NAS I think it's prudent not to assume we're on server grade hardware
|
||||
// https://cheatsheetseries.owasp.org/cheatsheets/Password_Storage_Cheat_Sheet.html#argon2id
|
||||
argon2MemoryCost = 12 * 1024
|
||||
argon2TimeCost = 3
|
||||
argon2Threads = 1
|
||||
argon2KeyLength = 32
|
||||
|
||||
pbkdf2Iterations = 600_000 // use recommended iterations from https://cheatsheetseries.owasp.org/cheatsheets/Password_Storage_Cheat_Sheet.html#pbkdf2 a little overkill for this use
|
||||
pbkdf2SaltLength = 32
|
||||
)
|
||||
|
||||
// AesEncrypt reads from input, encrypts with AES-256 and writes to output. passphrase is used to generate an encryption key
|
||||
func AesEncrypt(input io.Reader, output io.Writer, passphrase []byte) error {
|
||||
if err := aesEncryptGCM(input, output, passphrase); err != nil {
|
||||
return fmt.Errorf("error encrypting file: %w", err)
|
||||
if fips.FIPSMode() {
|
||||
if err := aesEncryptGCMFIPS(input, output, passphrase); err != nil {
|
||||
return fmt.Errorf("error encrypting file: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := aesEncryptGCM(input, output, passphrase); err != nil {
|
||||
return fmt.Errorf("error encrypting file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -40,14 +57,35 @@ func AesEncrypt(input io.Reader, output io.Writer, passphrase []byte) error {
|
||||
|
||||
// AesDecrypt reads from input, decrypts with AES-256 and returns the reader to read the decrypted content from
|
||||
func AesDecrypt(input io.Reader, passphrase []byte) (io.Reader, error) {
|
||||
return aesDecrypt(input, passphrase, fips.FIPSMode())
|
||||
}
|
||||
|
||||
func aesDecrypt(input io.Reader, passphrase []byte, fipsMode bool) (io.Reader, error) {
|
||||
// Read file header to determine how it was encrypted
|
||||
inputReader := bufio.NewReader(input)
|
||||
header, err := inputReader.Peek(len(aesGcmHeader))
|
||||
header, err := inputReader.Peek(len(aesGcmFIPSHeader))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading encrypted backup file header: %w", err)
|
||||
}
|
||||
|
||||
if string(header) == aesGcmHeader {
|
||||
if strings.HasPrefix(string(header), aesGcmFIPSHeader) {
|
||||
if !fipsMode {
|
||||
return nil, errors.New("fips encrypted file detected but fips mode is not enabled")
|
||||
}
|
||||
|
||||
reader, err := aesDecryptGCMFIPS(inputReader, passphrase)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error decrypting file: %w", err)
|
||||
}
|
||||
|
||||
return reader, nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(string(header), aesGcmHeader) {
|
||||
if fipsMode {
|
||||
return nil, errors.New("fips mode is enabled but non-fips encrypted file detected")
|
||||
}
|
||||
|
||||
reader, err := aesDecryptGCM(inputReader, passphrase)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error decrypting file: %w", err)
|
||||
@@ -114,15 +152,14 @@ func aesEncryptGCM(input io.Reader, output io.Writer, passphrase []byte) error {
|
||||
break // end of plaintext input
|
||||
}
|
||||
|
||||
if err != nil && !(errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)) {
|
||||
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return err
|
||||
}
|
||||
|
||||
// Seal encrypts the plaintext using the nonce returning the updated slice.
|
||||
ciphertext = aesgcm.Seal(ciphertext[:0], nonce.Value(), buf[:n], nil)
|
||||
|
||||
_, err = output.Write(ciphertext)
|
||||
if err != nil {
|
||||
if _, err := output.Write(ciphertext); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -183,7 +220,7 @@ func aesDecryptGCM(input io.Reader, passphrase []byte) (io.Reader, error) {
|
||||
break // end of ciphertext
|
||||
}
|
||||
|
||||
if err != nil && !(errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)) {
|
||||
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -203,15 +240,138 @@ func aesDecryptGCM(input io.Reader, passphrase []byte) (io.Reader, error) {
|
||||
return &buf, nil
|
||||
}
|
||||
|
||||
// aesEncryptGCMFIPS reads from input, encrypts with AES-256 in a fips compliant
|
||||
// way and writes to output. passphrase is used to generate an encryption key.
|
||||
func aesEncryptGCMFIPS(input io.Reader, output io.Writer, passphrase []byte) error {
|
||||
salt := make([]byte, pbkdf2SaltLength)
|
||||
if _, err := io.ReadFull(rand.Reader, salt); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
key, err := pbkdf2.Key(sha256.New, string(passphrase), salt, pbkdf2Iterations, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error deriving key: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// write the header
|
||||
if _, err := output.Write([]byte(aesGcmFIPSHeader)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write nonce and salt to the output file
|
||||
if _, err := output.Write(salt); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Buffer for reading plaintext blocks
|
||||
buf := make([]byte, aesGcmFIPSBlockSize)
|
||||
|
||||
// Encrypt plaintext in blocks
|
||||
for {
|
||||
// new random nonce for each block
|
||||
aesgcm, err := cipher.NewGCMWithRandomNonce(block)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating gcm: %w", err)
|
||||
}
|
||||
|
||||
n, err := io.ReadFull(input, buf)
|
||||
if n == 0 {
|
||||
break // end of plaintext input
|
||||
}
|
||||
|
||||
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return err
|
||||
}
|
||||
|
||||
// Seal encrypts the plaintext
|
||||
ciphertext := aesgcm.Seal(nil, nil, buf[:n], nil)
|
||||
|
||||
if _, err := output.Write(ciphertext); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// aesDecryptGCMFIPS reads from input, decrypts with AES-256 in a fips compliant
|
||||
// way and returns the reader to read the decrypted content from.
|
||||
func aesDecryptGCMFIPS(input io.Reader, passphrase []byte) (io.Reader, error) {
|
||||
// Reader & verify header
|
||||
header := make([]byte, len(aesGcmFIPSHeader))
|
||||
if _, err := io.ReadFull(input, header); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if string(header) != aesGcmFIPSHeader {
|
||||
return nil, errors.New("invalid header")
|
||||
}
|
||||
|
||||
// Read salt
|
||||
salt := make([]byte, pbkdf2SaltLength)
|
||||
if _, err := io.ReadFull(input, salt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key, err := pbkdf2.Key(sha256.New, string(passphrase), salt, pbkdf2Iterations, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error deriving key: %w", err)
|
||||
}
|
||||
|
||||
// Initialize AES cipher block
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Initialize a buffer to store decrypted data
|
||||
buf := bytes.Buffer{}
|
||||
|
||||
// Decrypt the ciphertext in blocks
|
||||
for {
|
||||
// Create GCM mode with the cipher block
|
||||
aesgcm, err := cipher.NewGCMWithRandomNonce(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Read a block of ciphertext from the input reader
|
||||
ciphertextBlock := make([]byte, aesGcmFIPSBlockSize+aesgcm.Overhead())
|
||||
n, err := io.ReadFull(input, ciphertextBlock)
|
||||
if n == 0 {
|
||||
break // end of ciphertext
|
||||
}
|
||||
|
||||
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Decrypt the block of ciphertext
|
||||
plaintext, err := aesgcm.Open(nil, nil, ciphertextBlock[:n], nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := buf.Write(plaintext); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &buf, nil
|
||||
}
|
||||
|
||||
// aesDecryptOFB reads from input, decrypts with AES-256 and returns the reader to a read decrypted content from.
|
||||
// passphrase is used to generate an encryption key.
|
||||
// note: This function used to decrypt files that were encrypted without a header i.e. old archives
|
||||
func aesDecryptOFB(input io.Reader, passphrase []byte) (io.Reader, error) {
|
||||
var emptySalt []byte = make([]byte, 0)
|
||||
|
||||
// making a 32 bytes key that would correspond to AES-256
|
||||
// don't necessarily need a salt, so just kept in empty
|
||||
key, err := scrypt.Key(passphrase, emptySalt, 32768, 8, 1, 32)
|
||||
key, err := scrypt.Key(passphrase, nil, 32768, 8, 1, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -228,3 +388,18 @@ func aesDecryptOFB(input io.Reader, passphrase []byte) (io.Reader, error) {
|
||||
|
||||
return reader, nil
|
||||
}
|
||||
|
||||
// HasEncryptedHeader checks if the data has an encrypted header, note that fips
|
||||
// mode changes this behavior and so will only recognize data encrypted by the
|
||||
// same mode (fips enabled or disabled)
|
||||
func HasEncryptedHeader(data []byte) bool {
|
||||
return hasEncryptedHeader(data, fips.FIPSMode())
|
||||
}
|
||||
|
||||
func hasEncryptedHeader(data []byte, fipsMode bool) bool {
|
||||
if fipsMode {
|
||||
return bytes.HasPrefix(data, []byte(aesGcmFIPSHeader))
|
||||
}
|
||||
|
||||
return bytes.HasPrefix(data, []byte(aesGcmHeader))
|
||||
}
|
||||
|
||||
@@ -1,15 +1,25 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"io"
|
||||
"math/rand"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/scrypt"
|
||||
)
|
||||
|
||||
func init() {
|
||||
fips.InitFIPS(false)
|
||||
}
|
||||
|
||||
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
|
||||
func randBytes(n int) []byte {
|
||||
@@ -17,201 +27,385 @@ func randBytes(n int) []byte {
|
||||
for i := range b {
|
||||
b[i] = letterBytes[rand.Intn(len(letterBytes))]
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
type encryptFunc func(input io.Reader, output io.Writer, passphrase []byte) error
|
||||
type decryptFunc func(input io.Reader, passphrase []byte) (io.Reader, error)
|
||||
|
||||
func Test_encryptAndDecrypt_withTheSamePassword(t *testing.T) {
|
||||
const passphrase = "passphrase"
|
||||
|
||||
tmpdir := t.TempDir()
|
||||
testFunc := func(t *testing.T, encrypt encryptFunc, decrypt decryptFunc, decryptShouldSucceed bool) {
|
||||
tmpdir := t.TempDir()
|
||||
|
||||
var (
|
||||
originFilePath = filepath.Join(tmpdir, "origin")
|
||||
encryptedFilePath = filepath.Join(tmpdir, "encrypted")
|
||||
decryptedFilePath = filepath.Join(tmpdir, "decrypted")
|
||||
)
|
||||
var (
|
||||
originFilePath = filepath.Join(tmpdir, "origin")
|
||||
encryptedFilePath = filepath.Join(tmpdir, "encrypted")
|
||||
decryptedFilePath = filepath.Join(tmpdir, "decrypted")
|
||||
)
|
||||
|
||||
content := randBytes(1024*1024*100 + 523)
|
||||
os.WriteFile(originFilePath, content, 0600)
|
||||
content := randBytes(1024*1024*100 + 523)
|
||||
os.WriteFile(originFilePath, content, 0600)
|
||||
|
||||
originFile, _ := os.Open(originFilePath)
|
||||
defer originFile.Close()
|
||||
originFile, _ := os.Open(originFilePath)
|
||||
defer originFile.Close()
|
||||
|
||||
encryptedFileWriter, _ := os.Create(encryptedFilePath)
|
||||
encryptedFileWriter, _ := os.Create(encryptedFilePath)
|
||||
|
||||
err := AesEncrypt(originFile, encryptedFileWriter, []byte(passphrase))
|
||||
assert.Nil(t, err, "Failed to encrypt a file")
|
||||
encryptedFileWriter.Close()
|
||||
err := encrypt(originFile, encryptedFileWriter, []byte(passphrase))
|
||||
require.Nil(t, err, "Failed to encrypt a file")
|
||||
encryptedFileWriter.Close()
|
||||
|
||||
encryptedContent, err := os.ReadFile(encryptedFilePath)
|
||||
assert.Nil(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
encryptedContent, err := os.ReadFile(encryptedFilePath)
|
||||
require.Nil(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
|
||||
encryptedFileReader, _ := os.Open(encryptedFilePath)
|
||||
defer encryptedFileReader.Close()
|
||||
encryptedFileReader, _ := os.Open(encryptedFilePath)
|
||||
defer encryptedFileReader.Close()
|
||||
|
||||
decryptedFileWriter, _ := os.Create(decryptedFilePath)
|
||||
defer decryptedFileWriter.Close()
|
||||
decryptedFileWriter, _ := os.Create(decryptedFilePath)
|
||||
defer decryptedFileWriter.Close()
|
||||
|
||||
decryptedReader, err := AesDecrypt(encryptedFileReader, []byte(passphrase))
|
||||
assert.Nil(t, err, "Failed to decrypt file")
|
||||
decryptedReader, err := decrypt(encryptedFileReader, []byte(passphrase))
|
||||
if !decryptShouldSucceed {
|
||||
require.Error(t, err, "Failed to decrypt file as indicated by decryptShouldSucceed")
|
||||
} else {
|
||||
require.NoError(t, err, "Failed to decrypt file indicated by decryptShouldSucceed")
|
||||
|
||||
io.Copy(decryptedFileWriter, decryptedReader)
|
||||
io.Copy(decryptedFileWriter, decryptedReader)
|
||||
|
||||
decryptedContent, _ := os.ReadFile(decryptedFilePath)
|
||||
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
|
||||
decryptedContent, _ := os.ReadFile(decryptedFilePath)
|
||||
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("fips", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCMFIPS, aesDecryptGCMFIPS, true)
|
||||
})
|
||||
|
||||
t.Run("non_fips", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCM, aesDecryptGCM, true)
|
||||
})
|
||||
|
||||
t.Run("system_fips_mode_public_entry_points", func(t *testing.T) {
|
||||
// use the init mode, public entry points
|
||||
testFunc(t, AesEncrypt, AesDecrypt, true)
|
||||
})
|
||||
|
||||
t.Run("fips_encrypted_file_header_fails_in_non_fips_mode", func(t *testing.T) {
|
||||
// use aesDecrypt which checks the header, confirm that it fails
|
||||
decrypt := func(input io.Reader, passphrase []byte) (io.Reader, error) {
|
||||
return aesDecrypt(input, passphrase, false)
|
||||
}
|
||||
|
||||
testFunc(t, aesEncryptGCMFIPS, decrypt, false)
|
||||
})
|
||||
|
||||
t.Run("non_fips_encrypted_file_header_fails_in_fips_mode", func(t *testing.T) {
|
||||
// use aesDecrypt which checks the header, confirm that it fails
|
||||
decrypt := func(input io.Reader, passphrase []byte) (io.Reader, error) {
|
||||
return aesDecrypt(input, passphrase, true)
|
||||
}
|
||||
|
||||
testFunc(t, aesEncryptGCM, decrypt, false)
|
||||
})
|
||||
|
||||
t.Run("fips_encrypted_file_fails_in_non_fips_mode", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCMFIPS, aesDecryptGCM, false)
|
||||
})
|
||||
|
||||
t.Run("non_fips_encrypted_file_with_fips_mode_should_fail", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCM, aesDecryptGCMFIPS, false)
|
||||
})
|
||||
|
||||
t.Run("fips_with_base_aesDecrypt", func(t *testing.T) {
|
||||
// maximize coverage, use the base aesDecrypt function with valid fips mode
|
||||
decrypt := func(input io.Reader, passphrase []byte) (io.Reader, error) {
|
||||
return aesDecrypt(input, passphrase, true)
|
||||
}
|
||||
|
||||
testFunc(t, aesEncryptGCMFIPS, decrypt, true)
|
||||
})
|
||||
|
||||
t.Run("legacy", func(t *testing.T) {
|
||||
testFunc(t, legacyAesEncrypt, aesDecryptOFB, true)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_encryptAndDecrypt_withStrongPassphrase(t *testing.T) {
|
||||
const passphrase = "A strong passphrase with special characters: !@#$%^&*()_+"
|
||||
tmpdir := t.TempDir()
|
||||
|
||||
var (
|
||||
originFilePath = filepath.Join(tmpdir, "origin2")
|
||||
encryptedFilePath = filepath.Join(tmpdir, "encrypted2")
|
||||
decryptedFilePath = filepath.Join(tmpdir, "decrypted2")
|
||||
)
|
||||
testFunc := func(t *testing.T, encrypt encryptFunc, decrypt decryptFunc) {
|
||||
tmpdir := t.TempDir()
|
||||
|
||||
content := randBytes(500)
|
||||
os.WriteFile(originFilePath, content, 0600)
|
||||
var (
|
||||
originFilePath = filepath.Join(tmpdir, "origin2")
|
||||
encryptedFilePath = filepath.Join(tmpdir, "encrypted2")
|
||||
decryptedFilePath = filepath.Join(tmpdir, "decrypted2")
|
||||
)
|
||||
|
||||
originFile, _ := os.Open(originFilePath)
|
||||
defer originFile.Close()
|
||||
content := randBytes(500)
|
||||
os.WriteFile(originFilePath, content, 0600)
|
||||
|
||||
encryptedFileWriter, _ := os.Create(encryptedFilePath)
|
||||
originFile, _ := os.Open(originFilePath)
|
||||
defer originFile.Close()
|
||||
|
||||
err := AesEncrypt(originFile, encryptedFileWriter, []byte(passphrase))
|
||||
assert.Nil(t, err, "Failed to encrypt a file")
|
||||
encryptedFileWriter.Close()
|
||||
encryptedFileWriter, _ := os.Create(encryptedFilePath)
|
||||
|
||||
encryptedContent, err := os.ReadFile(encryptedFilePath)
|
||||
assert.Nil(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
err := encrypt(originFile, encryptedFileWriter, []byte(passphrase))
|
||||
assert.Nil(t, err, "Failed to encrypt a file")
|
||||
encryptedFileWriter.Close()
|
||||
|
||||
encryptedFileReader, _ := os.Open(encryptedFilePath)
|
||||
defer encryptedFileReader.Close()
|
||||
encryptedContent, err := os.ReadFile(encryptedFilePath)
|
||||
assert.Nil(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
|
||||
decryptedFileWriter, _ := os.Create(decryptedFilePath)
|
||||
defer decryptedFileWriter.Close()
|
||||
encryptedFileReader, _ := os.Open(encryptedFilePath)
|
||||
defer encryptedFileReader.Close()
|
||||
|
||||
decryptedReader, err := AesDecrypt(encryptedFileReader, []byte(passphrase))
|
||||
assert.Nil(t, err, "Failed to decrypt file")
|
||||
decryptedFileWriter, _ := os.Create(decryptedFilePath)
|
||||
defer decryptedFileWriter.Close()
|
||||
|
||||
io.Copy(decryptedFileWriter, decryptedReader)
|
||||
decryptedReader, err := decrypt(encryptedFileReader, []byte(passphrase))
|
||||
assert.Nil(t, err, "Failed to decrypt file")
|
||||
|
||||
decryptedContent, _ := os.ReadFile(decryptedFilePath)
|
||||
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
|
||||
io.Copy(decryptedFileWriter, decryptedReader)
|
||||
|
||||
decryptedContent, _ := os.ReadFile(decryptedFilePath)
|
||||
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
|
||||
}
|
||||
|
||||
t.Run("fips", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCMFIPS, aesDecryptGCMFIPS)
|
||||
})
|
||||
|
||||
t.Run("non_fips", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCM, aesDecryptGCM)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_encryptAndDecrypt_withTheSamePasswordSmallFile(t *testing.T) {
|
||||
tmpdir := t.TempDir()
|
||||
testFunc := func(t *testing.T, encrypt encryptFunc, decrypt decryptFunc) {
|
||||
tmpdir := t.TempDir()
|
||||
|
||||
var (
|
||||
originFilePath = filepath.Join(tmpdir, "origin2")
|
||||
encryptedFilePath = filepath.Join(tmpdir, "encrypted2")
|
||||
decryptedFilePath = filepath.Join(tmpdir, "decrypted2")
|
||||
)
|
||||
var (
|
||||
originFilePath = filepath.Join(tmpdir, "origin2")
|
||||
encryptedFilePath = filepath.Join(tmpdir, "encrypted2")
|
||||
decryptedFilePath = filepath.Join(tmpdir, "decrypted2")
|
||||
)
|
||||
|
||||
content := randBytes(500)
|
||||
os.WriteFile(originFilePath, content, 0600)
|
||||
content := randBytes(500)
|
||||
os.WriteFile(originFilePath, content, 0600)
|
||||
|
||||
originFile, _ := os.Open(originFilePath)
|
||||
defer originFile.Close()
|
||||
originFile, _ := os.Open(originFilePath)
|
||||
defer originFile.Close()
|
||||
|
||||
encryptedFileWriter, _ := os.Create(encryptedFilePath)
|
||||
encryptedFileWriter, _ := os.Create(encryptedFilePath)
|
||||
|
||||
err := AesEncrypt(originFile, encryptedFileWriter, []byte("passphrase"))
|
||||
assert.Nil(t, err, "Failed to encrypt a file")
|
||||
encryptedFileWriter.Close()
|
||||
err := encrypt(originFile, encryptedFileWriter, []byte("passphrase"))
|
||||
assert.Nil(t, err, "Failed to encrypt a file")
|
||||
encryptedFileWriter.Close()
|
||||
|
||||
encryptedContent, err := os.ReadFile(encryptedFilePath)
|
||||
assert.Nil(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
encryptedContent, err := os.ReadFile(encryptedFilePath)
|
||||
assert.Nil(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
|
||||
encryptedFileReader, _ := os.Open(encryptedFilePath)
|
||||
defer encryptedFileReader.Close()
|
||||
encryptedFileReader, _ := os.Open(encryptedFilePath)
|
||||
defer encryptedFileReader.Close()
|
||||
|
||||
decryptedFileWriter, _ := os.Create(decryptedFilePath)
|
||||
defer decryptedFileWriter.Close()
|
||||
decryptedFileWriter, _ := os.Create(decryptedFilePath)
|
||||
defer decryptedFileWriter.Close()
|
||||
|
||||
decryptedReader, err := AesDecrypt(encryptedFileReader, []byte("passphrase"))
|
||||
assert.Nil(t, err, "Failed to decrypt file")
|
||||
decryptedReader, err := decrypt(encryptedFileReader, []byte("passphrase"))
|
||||
assert.Nil(t, err, "Failed to decrypt file")
|
||||
|
||||
io.Copy(decryptedFileWriter, decryptedReader)
|
||||
io.Copy(decryptedFileWriter, decryptedReader)
|
||||
|
||||
decryptedContent, _ := os.ReadFile(decryptedFilePath)
|
||||
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
|
||||
decryptedContent, _ := os.ReadFile(decryptedFilePath)
|
||||
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
|
||||
}
|
||||
|
||||
t.Run("fips", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCMFIPS, aesDecryptGCMFIPS)
|
||||
})
|
||||
|
||||
t.Run("non_fips", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCM, aesDecryptGCM)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_encryptAndDecrypt_withEmptyPassword(t *testing.T) {
|
||||
tmpdir := t.TempDir()
|
||||
testFunc := func(t *testing.T, encrypt encryptFunc, decrypt decryptFunc) {
|
||||
tmpdir := t.TempDir()
|
||||
|
||||
var (
|
||||
originFilePath = filepath.Join(tmpdir, "origin")
|
||||
encryptedFilePath = filepath.Join(tmpdir, "encrypted")
|
||||
decryptedFilePath = filepath.Join(tmpdir, "decrypted")
|
||||
)
|
||||
var (
|
||||
originFilePath = filepath.Join(tmpdir, "origin")
|
||||
encryptedFilePath = filepath.Join(tmpdir, "encrypted")
|
||||
decryptedFilePath = filepath.Join(tmpdir, "decrypted")
|
||||
)
|
||||
|
||||
content := randBytes(1024 * 50)
|
||||
os.WriteFile(originFilePath, content, 0600)
|
||||
content := randBytes(1024 * 50)
|
||||
os.WriteFile(originFilePath, content, 0600)
|
||||
|
||||
originFile, _ := os.Open(originFilePath)
|
||||
defer originFile.Close()
|
||||
originFile, _ := os.Open(originFilePath)
|
||||
defer originFile.Close()
|
||||
|
||||
encryptedFileWriter, _ := os.Create(encryptedFilePath)
|
||||
defer encryptedFileWriter.Close()
|
||||
encryptedFileWriter, _ := os.Create(encryptedFilePath)
|
||||
defer encryptedFileWriter.Close()
|
||||
|
||||
err := AesEncrypt(originFile, encryptedFileWriter, []byte(""))
|
||||
assert.Nil(t, err, "Failed to encrypt a file")
|
||||
encryptedContent, err := os.ReadFile(encryptedFilePath)
|
||||
assert.Nil(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
err := encrypt(originFile, encryptedFileWriter, []byte(""))
|
||||
assert.Nil(t, err, "Failed to encrypt a file")
|
||||
encryptedContent, err := os.ReadFile(encryptedFilePath)
|
||||
assert.Nil(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
|
||||
encryptedFileReader, _ := os.Open(encryptedFilePath)
|
||||
defer encryptedFileReader.Close()
|
||||
encryptedFileReader, _ := os.Open(encryptedFilePath)
|
||||
defer encryptedFileReader.Close()
|
||||
|
||||
decryptedFileWriter, _ := os.Create(decryptedFilePath)
|
||||
defer decryptedFileWriter.Close()
|
||||
decryptedFileWriter, _ := os.Create(decryptedFilePath)
|
||||
defer decryptedFileWriter.Close()
|
||||
|
||||
decryptedReader, err := AesDecrypt(encryptedFileReader, []byte(""))
|
||||
assert.Nil(t, err, "Failed to decrypt file")
|
||||
decryptedReader, err := decrypt(encryptedFileReader, []byte(""))
|
||||
assert.Nil(t, err, "Failed to decrypt file")
|
||||
|
||||
io.Copy(decryptedFileWriter, decryptedReader)
|
||||
io.Copy(decryptedFileWriter, decryptedReader)
|
||||
|
||||
decryptedContent, _ := os.ReadFile(decryptedFilePath)
|
||||
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
|
||||
decryptedContent, _ := os.ReadFile(decryptedFilePath)
|
||||
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
|
||||
}
|
||||
|
||||
t.Run("fips", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCMFIPS, aesDecryptGCMFIPS)
|
||||
})
|
||||
|
||||
t.Run("non_fips", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCM, aesDecryptGCM)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_decryptWithDifferentPassphrase_shouldProduceWrongResult(t *testing.T) {
|
||||
tmpdir := t.TempDir()
|
||||
testFunc := func(t *testing.T, encrypt encryptFunc, decrypt decryptFunc) {
|
||||
tmpdir := t.TempDir()
|
||||
|
||||
var (
|
||||
originFilePath = filepath.Join(tmpdir, "origin")
|
||||
encryptedFilePath = filepath.Join(tmpdir, "encrypted")
|
||||
decryptedFilePath = filepath.Join(tmpdir, "decrypted")
|
||||
)
|
||||
var (
|
||||
originFilePath = filepath.Join(tmpdir, "origin")
|
||||
encryptedFilePath = filepath.Join(tmpdir, "encrypted")
|
||||
decryptedFilePath = filepath.Join(tmpdir, "decrypted")
|
||||
)
|
||||
|
||||
content := randBytes(1034)
|
||||
os.WriteFile(originFilePath, content, 0600)
|
||||
content := randBytes(1034)
|
||||
os.WriteFile(originFilePath, content, 0600)
|
||||
|
||||
originFile, _ := os.Open(originFilePath)
|
||||
defer originFile.Close()
|
||||
originFile, _ := os.Open(originFilePath)
|
||||
defer originFile.Close()
|
||||
|
||||
encryptedFileWriter, _ := os.Create(encryptedFilePath)
|
||||
defer encryptedFileWriter.Close()
|
||||
encryptedFileWriter, _ := os.Create(encryptedFilePath)
|
||||
defer encryptedFileWriter.Close()
|
||||
|
||||
err := AesEncrypt(originFile, encryptedFileWriter, []byte("passphrase"))
|
||||
assert.Nil(t, err, "Failed to encrypt a file")
|
||||
encryptedContent, err := os.ReadFile(encryptedFilePath)
|
||||
assert.Nil(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
err := encrypt(originFile, encryptedFileWriter, []byte("passphrase"))
|
||||
assert.Nil(t, err, "Failed to encrypt a file")
|
||||
encryptedContent, err := os.ReadFile(encryptedFilePath)
|
||||
assert.Nil(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
|
||||
encryptedFileReader, _ := os.Open(encryptedFilePath)
|
||||
defer encryptedFileReader.Close()
|
||||
encryptedFileReader, _ := os.Open(encryptedFilePath)
|
||||
defer encryptedFileReader.Close()
|
||||
|
||||
decryptedFileWriter, _ := os.Create(decryptedFilePath)
|
||||
defer decryptedFileWriter.Close()
|
||||
decryptedFileWriter, _ := os.Create(decryptedFilePath)
|
||||
defer decryptedFileWriter.Close()
|
||||
|
||||
_, err = AesDecrypt(encryptedFileReader, []byte("garbage"))
|
||||
assert.NotNil(t, err, "Should not allow decrypt with wrong passphrase")
|
||||
_, err = decrypt(encryptedFileReader, []byte("garbage"))
|
||||
assert.NotNil(t, err, "Should not allow decrypt with wrong passphrase")
|
||||
}
|
||||
|
||||
t.Run("fips", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCMFIPS, aesDecryptGCMFIPS)
|
||||
})
|
||||
|
||||
t.Run("non_fips", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCM, aesDecryptGCM)
|
||||
})
|
||||
}
|
||||
|
||||
func legacyAesEncrypt(input io.Reader, output io.Writer, passphrase []byte) error {
|
||||
key, err := scrypt.Key(passphrase, nil, 32768, 8, 1, 32)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var iv [aes.BlockSize]byte
|
||||
stream := cipher.NewOFB(block, iv[:])
|
||||
|
||||
writer := &cipher.StreamWriter{S: stream, W: output}
|
||||
if _, err := io.Copy(writer, input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func Test_hasEncryptedHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
fipsMode bool
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "non-FIPS mode with valid header",
|
||||
data: []byte("AES256-GCM" + "some encrypted data"),
|
||||
fipsMode: false,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "non-FIPS mode with FIPS header",
|
||||
data: []byte("FIPS-AES256-GCM" + "some encrypted data"),
|
||||
fipsMode: false,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "FIPS mode with valid header",
|
||||
data: []byte("FIPS-AES256-GCM" + "some encrypted data"),
|
||||
fipsMode: true,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "FIPS mode with non-FIPS header",
|
||||
data: []byte("AES256-GCM" + "some encrypted data"),
|
||||
fipsMode: true,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "invalid header",
|
||||
data: []byte("INVALID-HEADER" + "some data"),
|
||||
fipsMode: false,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty data",
|
||||
data: []byte{},
|
||||
fipsMode: false,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "nil data",
|
||||
data: nil,
|
||||
fipsMode: false,
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := hasEncryptedHeader(tt.data, tt.fipsMode)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -112,7 +112,7 @@ func (service *ECDSAService) CreateSignature(message string) (string, error) {
|
||||
message = service.secret
|
||||
}
|
||||
|
||||
hash := libcrypto.HashFromBytes([]byte(message))
|
||||
hash := libcrypto.InsecureHashFromBytes([]byte(message))
|
||||
|
||||
r, s, err := ecdsa.Sign(rand.Reader, service.privateKey, hash)
|
||||
if err != nil {
|
||||
|
||||
22
api/crypto/ecdsa_test.go
Normal file
22
api/crypto/ecdsa_test.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCreateSignature(t *testing.T) {
|
||||
var s = NewECDSAService("secret")
|
||||
|
||||
privKey, pubKey, err := s.GenerateKeyPair()
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, len(privKey), 0)
|
||||
require.Greater(t, len(pubKey), 0)
|
||||
|
||||
m := "test message"
|
||||
r, err := s.CreateSignature(m)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, r, m)
|
||||
require.Greater(t, len(r), 0)
|
||||
}
|
||||
@@ -8,15 +8,16 @@ import (
|
||||
type Service struct{}
|
||||
|
||||
// Hash hashes a string using the bcrypt algorithm
|
||||
func (*Service) Hash(data string) (string, error) {
|
||||
func (Service) Hash(data string) (string, error) {
|
||||
bytes, err := bcrypt.GenerateFromPassword([]byte(data), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(bytes), err
|
||||
}
|
||||
|
||||
// CompareHashAndData compares a hash to clear data and returns an error if the comparison fails.
|
||||
func (*Service) CompareHashAndData(hash string, data string) error {
|
||||
func (Service) CompareHashAndData(hash string, data string) error {
|
||||
return bcrypt.CompareHashAndPassword([]byte(hash), []byte(data))
|
||||
}
|
||||
|
||||
@@ -2,10 +2,12 @@ package crypto
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestService_Hash(t *testing.T) {
|
||||
var s = &Service{}
|
||||
var s = Service{}
|
||||
|
||||
type args struct {
|
||||
hash string
|
||||
@@ -51,3 +53,11 @@ func TestService_Hash(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHash(t *testing.T) {
|
||||
s := Service{}
|
||||
|
||||
hash, err := s.Hash("Passw0rd!")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, hash)
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ func NewNonce(size int) *Nonce {
|
||||
}
|
||||
|
||||
// NewRandomNonce generates a new initial nonce with the lower byte set to a random value
|
||||
// This ensures there are plenty of nonce values availble before rolling over
|
||||
// This ensures there are plenty of nonce values available before rolling over
|
||||
// Based on ideas from the Secure Programming Cookbook for C and C++ by John Viega, Matt Messier
|
||||
// https://www.oreilly.com/library/view/secure-programming-cookbook/0596003943/ch04s09.html
|
||||
func NewRandomNonce(size int) (*Nonce, error) {
|
||||
|
||||
@@ -4,11 +4,32 @@ import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"os"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
)
|
||||
|
||||
// CreateTLSConfiguration creates a basic tls.Config with recommended TLS settings
|
||||
func CreateTLSConfiguration() *tls.Config {
|
||||
return &tls.Config{
|
||||
func CreateTLSConfiguration(insecureSkipVerify bool) *tls.Config { //nolint:forbidigo
|
||||
return createTLSConfiguration(fips.FIPSMode(), insecureSkipVerify)
|
||||
}
|
||||
|
||||
func createTLSConfiguration(fipsEnabled bool, insecureSkipVerify bool) *tls.Config { //nolint:forbidigo
|
||||
if fipsEnabled {
|
||||
return &tls.Config{ //nolint:forbidigo
|
||||
MinVersion: tls.VersionTLS12,
|
||||
MaxVersion: tls.VersionTLS13,
|
||||
CipherSuites: []uint16{
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
},
|
||||
CurvePreferences: []tls.CurveID{tls.CurveP256, tls.CurveP384, tls.CurveP521},
|
||||
}
|
||||
}
|
||||
|
||||
return &tls.Config{ //nolint:forbidigo
|
||||
MinVersion: tls.VersionTLS12,
|
||||
CipherSuites: []uint16{
|
||||
tls.TLS_AES_128_GCM_SHA256,
|
||||
@@ -29,24 +50,33 @@ func CreateTLSConfiguration() *tls.Config {
|
||||
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||
},
|
||||
InsecureSkipVerify: insecureSkipVerify, //nolint:forbidigo
|
||||
}
|
||||
}
|
||||
|
||||
// CreateTLSConfigurationFromBytes initializes a tls.Config using a CA certificate, a certificate and a key
|
||||
// loaded from memory.
|
||||
func CreateTLSConfigurationFromBytes(caCert, cert, key []byte, skipClientVerification, skipServerVerification bool) (*tls.Config, error) {
|
||||
config := CreateTLSConfiguration()
|
||||
config.InsecureSkipVerify = skipServerVerification
|
||||
func CreateTLSConfigurationFromBytes(useTLS bool, caCert, cert, key []byte, skipClientVerification, skipServerVerification bool) (*tls.Config, error) { //nolint:forbidigo
|
||||
return createTLSConfigurationFromBytes(fips.FIPSMode(), useTLS, caCert, cert, key, skipClientVerification, skipServerVerification)
|
||||
}
|
||||
|
||||
if !skipClientVerification {
|
||||
func createTLSConfigurationFromBytes(fipsEnabled, useTLS bool, caCert, cert, key []byte, skipClientVerification, skipServerVerification bool) (*tls.Config, error) { //nolint:forbidigo
|
||||
if !useTLS {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
config := createTLSConfiguration(fipsEnabled, skipServerVerification)
|
||||
|
||||
if !skipClientVerification || fipsEnabled {
|
||||
certificate, err := tls.X509KeyPair(cert, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config.Certificates = []tls.Certificate{certificate}
|
||||
}
|
||||
|
||||
if !skipServerVerification {
|
||||
if !skipServerVerification || fipsEnabled {
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCertPool.AppendCertsFromPEM(caCert)
|
||||
config.RootCAs = caCertPool
|
||||
@@ -57,29 +87,36 @@ func CreateTLSConfigurationFromBytes(caCert, cert, key []byte, skipClientVerific
|
||||
|
||||
// CreateTLSConfigurationFromDisk initializes a tls.Config using a CA certificate, a certificate and a key
|
||||
// loaded from disk.
|
||||
func CreateTLSConfigurationFromDisk(caCertPath, certPath, keyPath string, skipServerVerification bool) (*tls.Config, error) {
|
||||
config := CreateTLSConfiguration()
|
||||
config.InsecureSkipVerify = skipServerVerification
|
||||
func CreateTLSConfigurationFromDisk(config portainer.TLSConfiguration) (*tls.Config, error) { //nolint:forbidigo
|
||||
return createTLSConfigurationFromDisk(fips.FIPSMode(), config)
|
||||
}
|
||||
|
||||
if certPath != "" && keyPath != "" {
|
||||
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
|
||||
func createTLSConfigurationFromDisk(fipsEnabled bool, config portainer.TLSConfiguration) (*tls.Config, error) { //nolint:forbidigo
|
||||
if !config.TLS {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
tlsConfig := createTLSConfiguration(fipsEnabled, config.TLSSkipVerify)
|
||||
|
||||
if config.TLSCertPath != "" && config.TLSKeyPath != "" {
|
||||
cert, err := tls.LoadX509KeyPair(config.TLSCertPath, config.TLSKeyPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config.Certificates = []tls.Certificate{cert}
|
||||
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||
}
|
||||
|
||||
if !skipServerVerification && caCertPath != "" {
|
||||
caCert, err := os.ReadFile(caCertPath)
|
||||
if !tlsConfig.InsecureSkipVerify && config.TLSCACertPath != "" { //nolint:forbidigo
|
||||
caCert, err := os.ReadFile(config.TLSCACertPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCertPool.AppendCertsFromPEM(caCert)
|
||||
config.RootCAs = caCertPool
|
||||
tlsConfig.RootCAs = caCertPool
|
||||
}
|
||||
|
||||
return config, nil
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
87
api/crypto/tls_test.go
Normal file
87
api/crypto/tls_test.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCreateTLSConfiguration(t *testing.T) {
|
||||
// InsecureSkipVerify = false
|
||||
config := CreateTLSConfiguration(false)
|
||||
require.Equal(t, config.MinVersion, uint16(tls.VersionTLS12)) //nolint:forbidigo
|
||||
require.False(t, config.InsecureSkipVerify) //nolint:forbidigo
|
||||
|
||||
// InsecureSkipVerify = true
|
||||
config = CreateTLSConfiguration(true)
|
||||
require.Equal(t, config.MinVersion, uint16(tls.VersionTLS12)) //nolint:forbidigo
|
||||
require.True(t, config.InsecureSkipVerify) //nolint:forbidigo
|
||||
}
|
||||
|
||||
func TestCreateTLSConfigurationFIPS(t *testing.T) {
|
||||
fips := true
|
||||
|
||||
fipsCipherSuites := []uint16{
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
}
|
||||
|
||||
fipsCurvePreferences := []tls.CurveID{tls.CurveP256, tls.CurveP384, tls.CurveP521}
|
||||
|
||||
config := createTLSConfiguration(fips, false)
|
||||
require.Equal(t, config.MinVersion, uint16(tls.VersionTLS12)) //nolint:forbidigo
|
||||
require.Equal(t, config.MaxVersion, uint16(tls.VersionTLS13)) //nolint:forbidigo
|
||||
require.Equal(t, config.CipherSuites, fipsCipherSuites) //nolint:forbidigo
|
||||
require.Equal(t, config.CurvePreferences, fipsCurvePreferences) //nolint:forbidigo
|
||||
require.False(t, config.InsecureSkipVerify) //nolint:forbidigo
|
||||
}
|
||||
|
||||
func TestCreateTLSConfigurationFromBytes(t *testing.T) {
|
||||
// No TLS
|
||||
config, err := CreateTLSConfigurationFromBytes(false, nil, nil, nil, false, false)
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, config)
|
||||
|
||||
// Skip TLS client/server verifications
|
||||
config, err = CreateTLSConfigurationFromBytes(true, nil, nil, nil, true, true)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
|
||||
// Empty TLS
|
||||
config, err = CreateTLSConfigurationFromBytes(true, nil, nil, nil, false, false)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, config)
|
||||
}
|
||||
|
||||
func TestCreateTLSConfigurationFromDisk(t *testing.T) {
|
||||
// No TLS
|
||||
config, err := CreateTLSConfigurationFromDisk(portainer.TLSConfiguration{})
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, config)
|
||||
|
||||
// Skip TLS verifications
|
||||
config, err = CreateTLSConfigurationFromDisk(portainer.TLSConfiguration{
|
||||
TLS: true,
|
||||
TLSSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
}
|
||||
|
||||
func TestCreateTLSConfigurationFromDiskFIPS(t *testing.T) {
|
||||
fips := true
|
||||
|
||||
// Skipping TLS verifications cannot be done in FIPS mode
|
||||
config, err := createTLSConfigurationFromDisk(fips, portainer.TLSConfiguration{
|
||||
TLS: true,
|
||||
TLSSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
require.False(t, config.InsecureSkipVerify) //nolint:forbidigo
|
||||
}
|
||||
@@ -21,6 +21,9 @@ import (
|
||||
const (
|
||||
DatabaseFileName = "portainer.db"
|
||||
EncryptedDatabaseFileName = "portainer.edb"
|
||||
|
||||
txMaxSize = 65536
|
||||
compactedSuffix = ".compacted"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -35,6 +38,7 @@ type DbConnection struct {
|
||||
InitialMmapSize int
|
||||
EncryptionKey []byte
|
||||
isEncrypted bool
|
||||
Compact bool
|
||||
|
||||
*bolt.DB
|
||||
}
|
||||
@@ -132,13 +136,8 @@ func (connection *DbConnection) NeedsEncryptionMigration() (bool, error) {
|
||||
func (connection *DbConnection) Open() error {
|
||||
log.Info().Str("filename", connection.GetDatabaseFileName()).Msg("loading PortainerDB")
|
||||
|
||||
// Now we open the db
|
||||
databasePath := connection.GetDatabaseFilePath()
|
||||
|
||||
db, err := bolt.Open(databasePath, 0600, &bolt.Options{
|
||||
Timeout: 1 * time.Second,
|
||||
InitialMmapSize: connection.InitialMmapSize,
|
||||
})
|
||||
db, err := bolt.Open(databasePath, 0600, connection.boltOptions(connection.Compact))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -147,6 +146,24 @@ func (connection *DbConnection) Open() error {
|
||||
db.MaxBatchDelay = connection.MaxBatchDelay
|
||||
connection.DB = db
|
||||
|
||||
if connection.Compact {
|
||||
log.Info().Msg("compacting database")
|
||||
if err := connection.compact(); err != nil {
|
||||
log.Error().Err(err).Msg("failed to compact database")
|
||||
|
||||
// Close the read-only database and re-open in read-write mode
|
||||
if err := connection.Close(); err != nil {
|
||||
log.Warn().Err(err).Msg("failure to close the database after failed compaction")
|
||||
}
|
||||
|
||||
connection.Compact = false
|
||||
|
||||
return connection.Open()
|
||||
} else {
|
||||
log.Info().Msg("database compaction completed")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -244,6 +261,32 @@ func (connection *DbConnection) GetObject(bucketName string, key []byte, object
|
||||
})
|
||||
}
|
||||
|
||||
func (connection *DbConnection) GetRawBytes(bucketName string, key []byte) ([]byte, error) {
|
||||
var value []byte
|
||||
|
||||
err := connection.ViewTx(func(tx portainer.Transaction) error {
|
||||
var err error
|
||||
value, err = tx.GetRawBytes(bucketName, key)
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
return value, err
|
||||
}
|
||||
|
||||
func (connection *DbConnection) KeyExists(bucketName string, key []byte) (bool, error) {
|
||||
var exists bool
|
||||
|
||||
err := connection.ViewTx(func(tx portainer.Transaction) error {
|
||||
var err error
|
||||
exists, err = tx.KeyExists(bucketName, key)
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
return exists, err
|
||||
}
|
||||
|
||||
func (connection *DbConnection) getEncryptionKey() []byte {
|
||||
if !connection.isEncrypted {
|
||||
return nil
|
||||
@@ -386,3 +429,48 @@ func (connection *DbConnection) RestoreMetadata(s map[string]any) error {
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// compact attempts to compact the database and replace it iff it succeeds
|
||||
func (connection *DbConnection) compact() (err error) {
|
||||
compactedPath := connection.GetDatabaseFilePath() + compactedSuffix
|
||||
|
||||
if err := os.Remove(compactedPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return fmt.Errorf("failure to remove an existing compacted database: %w", err)
|
||||
}
|
||||
|
||||
compactedDB, err := bolt.Open(compactedPath, 0o600, connection.boltOptions(false))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failure to create the compacted database: %w", err)
|
||||
}
|
||||
|
||||
compactedDB.MaxBatchSize = connection.MaxBatchSize
|
||||
compactedDB.MaxBatchDelay = connection.MaxBatchDelay
|
||||
|
||||
if err := bolt.Compact(compactedDB, connection.DB, txMaxSize); err != nil {
|
||||
return fmt.Errorf("failure to compact the database: %w",
|
||||
errors.Join(err, compactedDB.Close(), os.Remove(compactedPath)))
|
||||
}
|
||||
|
||||
if err := os.Rename(compactedPath, connection.GetDatabaseFilePath()); err != nil {
|
||||
return fmt.Errorf("failure to move the compacted database: %w",
|
||||
errors.Join(err, compactedDB.Close(), os.Remove(compactedPath)))
|
||||
}
|
||||
|
||||
if err := connection.Close(); err != nil {
|
||||
log.Warn().Err(err).Msg("failure to close the database after compaction")
|
||||
}
|
||||
|
||||
connection.DB = compactedDB
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (connection *DbConnection) boltOptions(readOnly bool) *bolt.Options {
|
||||
return &bolt.Options{
|
||||
Timeout: 1 * time.Second,
|
||||
InitialMmapSize: connection.InitialMmapSize,
|
||||
FreelistType: bolt.FreelistMapType,
|
||||
NoFreelistSync: true,
|
||||
ReadOnly: readOnly,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,11 @@ import (
|
||||
"path"
|
||||
"testing"
|
||||
|
||||
"github.com/portainer/portainer/api/filesystem"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
func Test_NeedsEncryptionMigration(t *testing.T) {
|
||||
@@ -119,3 +123,59 @@ func Test_NeedsEncryptionMigration(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBCompaction(t *testing.T) {
|
||||
db := &DbConnection{Path: t.TempDir()}
|
||||
|
||||
err := db.Open()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Update(func(tx *bbolt.Tx) error {
|
||||
b, err := tx.CreateBucketIfNotExists([]byte("testbucket"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
b.Put([]byte("key"), []byte("value"))
|
||||
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Reopen the DB to trigger compaction
|
||||
db.Compact = true
|
||||
err = db.Open()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that the data is still there
|
||||
err = db.View(func(tx *bbolt.Tx) error {
|
||||
b := tx.Bucket([]byte("testbucket"))
|
||||
if b == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
val := b.Get([]byte("key"))
|
||||
require.Equal(t, []byte("value"), val)
|
||||
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Failures
|
||||
compactedPath := db.GetDatabaseFilePath() + compactedSuffix
|
||||
err = os.Mkdir(compactedPath, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
f, err := os.Create(filesystem.JoinPaths(compactedPath, "somefile"))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, f.Close())
|
||||
|
||||
err = db.Open()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -4,8 +4,6 @@ import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/segmentio/encoding/json"
|
||||
@@ -65,18 +63,18 @@ func (connection *DbConnection) UnmarshalObject(data []byte, object any) error {
|
||||
// https://gist.github.com/atoponce/07d8d4c833873be2f68c34f9afc5a78a#symmetric-encryption
|
||||
|
||||
func encrypt(plaintext []byte, passphrase []byte) (encrypted []byte, err error) {
|
||||
block, _ := aes.NewCipher(passphrase)
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
block, err := aes.NewCipher(passphrase)
|
||||
if err != nil {
|
||||
return encrypted, err
|
||||
}
|
||||
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
// NewGCMWithRandomNonce in go 1.24 handles setting up the nonce and adding it to the encrypted output
|
||||
gcm, err := cipher.NewGCMWithRandomNonce(block)
|
||||
if err != nil {
|
||||
return encrypted, err
|
||||
}
|
||||
|
||||
return gcm.Seal(nonce, nonce, plaintext, nil), nil
|
||||
return gcm.Seal(nil, nil, plaintext, nil), nil
|
||||
}
|
||||
|
||||
func decrypt(encrypted []byte, passphrase []byte) (plaintextByte []byte, err error) {
|
||||
@@ -89,19 +87,17 @@ func decrypt(encrypted []byte, passphrase []byte) (plaintextByte []byte, err err
|
||||
return encrypted, errors.Wrap(err, "Error creating cypher block")
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
// NewGCMWithRandomNonce in go 1.24 handles reading the nonce from the encrypted input for us
|
||||
gcm, err := cipher.NewGCMWithRandomNonce(block)
|
||||
if err != nil {
|
||||
return encrypted, errors.Wrap(err, "Error creating GCM")
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(encrypted) < nonceSize {
|
||||
if len(encrypted) < gcm.NonceSize() {
|
||||
return encrypted, errEncryptedStringTooShort
|
||||
}
|
||||
|
||||
nonce, ciphertextByteClean := encrypted[:nonceSize], encrypted[nonceSize:]
|
||||
|
||||
plaintextByte, err = gcm.Open(nil, nonce, ciphertextByteClean, nil)
|
||||
plaintextByte, err = gcm.Open(nil, nil, encrypted, nil)
|
||||
if err != nil {
|
||||
return encrypted, errors.Wrap(err, "Error decrypting text")
|
||||
}
|
||||
|
||||
@@ -1,16 +1,23 @@
|
||||
package boltdb
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
jsonobject = `{"LogoURL":"","BlackListedLabels":[],"AuthenticationMethod":1,"InternalAuthSettings": {"RequiredPasswordLength": 12}"LDAPSettings":{"AnonymousMode":true,"ReaderDN":"","URL":"","TLSConfig":{"TLS":false,"TLSSkipVerify":false},"StartTLS":false,"SearchSettings":[{"BaseDN":"","Filter":"","UserNameAttribute":""}],"GroupSearchSettings":[{"GroupBaseDN":"","GroupFilter":"","GroupAttribute":""}],"AutoCreateUsers":true},"OAuthSettings":{"ClientID":"","AccessTokenURI":"","AuthorizationURI":"","ResourceURI":"","RedirectURI":"","UserIdentifier":"","Scopes":"","OAuthAutoCreateUsers":false,"DefaultTeamID":0,"SSO":true,"LogoutURI":"","KubeSecretKey":"j0zLVtY/lAWBk62ByyF0uP80SOXaitsABP0TTJX8MhI="},"OpenAMTConfiguration":{"Enabled":false,"MPSServer":"","MPSUser":"","MPSPassword":"","MPSToken":"","CertFileContent":"","CertFileName":"","CertFilePassword":"","DomainName":""},"FeatureFlagSettings":{},"SnapshotInterval":"5m","TemplatesURL":"https://raw.githubusercontent.com/portainer/templates/master/templates-2.0.json","EdgeAgentCheckinInterval":5,"EnableEdgeComputeFeatures":false,"UserSessionTimeout":"8h","KubeconfigExpiry":"0","EnableTelemetry":true,"HelmRepositoryURL":"https://kubernetes.github.io/ingress-nginx","KubectlShellImage":"portainer/kubectl-shell","DisplayDonationHeader":false,"DisplayExternalContributors":false,"EnableHostManagementFeatures":false,"AllowVolumeBrowserForRegularUsers":false,"AllowBindMountsForRegularUsers":false,"AllowPrivilegedModeForRegularUsers":false,"AllowHostNamespaceForRegularUsers":false,"AllowStackManagementForRegularUsers":false,"AllowDeviceMappingForRegularUsers":false,"AllowContainerCapabilitiesForRegularUsers":false}`
|
||||
jsonobject = `{"LogoURL":"","BlackListedLabels":[],"AuthenticationMethod":1,"InternalAuthSettings": {"RequiredPasswordLength": 12}"LDAPSettings":{"AnonymousMode":true,"ReaderDN":"","URL":"","TLSConfig":{"TLS":false,"TLSSkipVerify":false},"StartTLS":false,"SearchSettings":[{"BaseDN":"","Filter":"","UserNameAttribute":""}],"GroupSearchSettings":[{"GroupBaseDN":"","GroupFilter":"","GroupAttribute":""}],"AutoCreateUsers":true},"OAuthSettings":{"ClientID":"","AccessTokenURI":"","AuthorizationURI":"","ResourceURI":"","RedirectURI":"","UserIdentifier":"","Scopes":"","OAuthAutoCreateUsers":false,"DefaultTeamID":0,"SSO":true,"LogoutURI":"","KubeSecretKey":"j0zLVtY/lAWBk62ByyF0uP80SOXaitsABP0TTJX8MhI="},"OpenAMTConfiguration":{"Enabled":false,"MPSServer":"","MPSUser":"","MPSPassword":"","MPSToken":"","CertFileContent":"","CertFileName":"","CertFilePassword":"","DomainName":""},"FeatureFlagSettings":{},"SnapshotInterval":"5m","TemplatesURL":"https://raw.githubusercontent.com/portainer/templates/master/templates-2.0.json","EdgeAgentCheckinInterval":5,"EnableEdgeComputeFeatures":false,"UserSessionTimeout":"8h","KubeconfigExpiry":"0","EnableTelemetry":true,"HelmRepositoryURL":"https://charts.bitnami.com/bitnami","KubectlShellImage":"portainer/kubectl-shell","DisplayDonationHeader":false,"DisplayExternalContributors":false,"EnableHostManagementFeatures":false,"AllowVolumeBrowserForRegularUsers":false,"AllowBindMountsForRegularUsers":false,"AllowPrivilegedModeForRegularUsers":false,"AllowHostNamespaceForRegularUsers":false,"AllowStackManagementForRegularUsers":false,"AllowDeviceMappingForRegularUsers":false,"AllowContainerCapabilitiesForRegularUsers":false}`
|
||||
passphrase = "my secret key"
|
||||
)
|
||||
|
||||
@@ -160,7 +167,7 @@ func Test_ObjectMarshallingEncrypted(t *testing.T) {
|
||||
}
|
||||
|
||||
key := secretToEncryptionKey(passphrase)
|
||||
conn := DbConnection{EncryptionKey: key}
|
||||
conn := DbConnection{EncryptionKey: key, isEncrypted: true}
|
||||
for _, test := range tests {
|
||||
t.Run(fmt.Sprintf("%s -> %s", test.object, test.expected), func(t *testing.T) {
|
||||
|
||||
@@ -175,3 +182,94 @@ func Test_ObjectMarshallingEncrypted(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_NonceSources(t *testing.T) {
|
||||
// ensure that the new go 1.24 NewGCMWithRandomNonce works correctly with
|
||||
// the old way of creating and including the nonce
|
||||
|
||||
encryptOldFn := func(plaintext []byte, passphrase []byte) (encrypted []byte, err error) {
|
||||
block, _ := aes.NewCipher(passphrase)
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return encrypted, err
|
||||
}
|
||||
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return encrypted, err
|
||||
}
|
||||
|
||||
return gcm.Seal(nonce, nonce, plaintext, nil), nil
|
||||
}
|
||||
|
||||
decryptOldFn := func(encrypted []byte, passphrase []byte) (plaintext []byte, err error) {
|
||||
block, err := aes.NewCipher(passphrase)
|
||||
if err != nil {
|
||||
return encrypted, errors.Wrap(err, "Error creating cypher block")
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return encrypted, errors.Wrap(err, "Error creating GCM")
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(encrypted) < nonceSize {
|
||||
return encrypted, errEncryptedStringTooShort
|
||||
}
|
||||
|
||||
nonce, ciphertextByteClean := encrypted[:nonceSize], encrypted[nonceSize:]
|
||||
|
||||
plaintext, err = gcm.Open(nil, nonce, ciphertextByteClean, nil)
|
||||
if err != nil {
|
||||
return encrypted, errors.Wrap(err, "Error decrypting text")
|
||||
}
|
||||
|
||||
return plaintext, err
|
||||
}
|
||||
|
||||
encryptNewFn := encrypt
|
||||
decryptNewFn := decrypt
|
||||
|
||||
passphrase := make([]byte, 32)
|
||||
_, err := io.ReadFull(rand.Reader, passphrase)
|
||||
require.NoError(t, err)
|
||||
|
||||
junk := make([]byte, 1024)
|
||||
_, err = io.ReadFull(rand.Reader, junk)
|
||||
require.NoError(t, err)
|
||||
|
||||
junkEnc := make([]byte, base64.StdEncoding.EncodedLen(len(junk)))
|
||||
base64.StdEncoding.Encode(junkEnc, junk)
|
||||
|
||||
cases := [][]byte{
|
||||
[]byte("test"),
|
||||
[]byte("35"),
|
||||
[]byte("9ca4a1dd-a439-4593-b386-a7dfdc2e9fc6"),
|
||||
[]byte(jsonobject),
|
||||
passphrase,
|
||||
junk,
|
||||
junkEnc,
|
||||
}
|
||||
|
||||
for _, plain := range cases {
|
||||
var enc, dec []byte
|
||||
var err error
|
||||
|
||||
enc, err = encryptOldFn(plain, passphrase)
|
||||
require.NoError(t, err)
|
||||
|
||||
dec, err = decryptNewFn(enc, passphrase)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, plain, dec)
|
||||
|
||||
enc, err = encryptNewFn(plain, passphrase)
|
||||
require.NoError(t, err)
|
||||
|
||||
dec, err = decryptOldFn(enc, passphrase)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, plain, dec)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
dserrors "github.com/portainer/portainer/api/dataservices/errors"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog/log"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
)
|
||||
@@ -31,6 +32,33 @@ func (tx *DbTransaction) GetObject(bucketName string, key []byte, object any) er
|
||||
return tx.conn.UnmarshalObject(value, object)
|
||||
}
|
||||
|
||||
func (tx *DbTransaction) GetRawBytes(bucketName string, key []byte) ([]byte, error) {
|
||||
bucket := tx.tx.Bucket([]byte(bucketName))
|
||||
|
||||
value := bucket.Get(key)
|
||||
if value == nil {
|
||||
return nil, fmt.Errorf("%w (bucket=%s, key=%s)", dserrors.ErrObjectNotFound, bucketName, keyToString(key))
|
||||
}
|
||||
|
||||
if tx.conn.getEncryptionKey() != nil {
|
||||
var err error
|
||||
|
||||
if value, err = decrypt(value, tx.conn.getEncryptionKey()); err != nil {
|
||||
return value, errors.Wrap(err, "Failed decrypting object")
|
||||
}
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (tx *DbTransaction) KeyExists(bucketName string, key []byte) (bool, error) {
|
||||
bucket := tx.tx.Bucket([]byte(bucketName))
|
||||
|
||||
value := bucket.Get(key)
|
||||
|
||||
return value != nil, nil
|
||||
}
|
||||
|
||||
func (tx *DbTransaction) UpdateObject(bucketName string, key []byte, object any) error {
|
||||
data, err := tx.conn.MarshalObject(object)
|
||||
if err != nil {
|
||||
|
||||
@@ -8,11 +8,12 @@ import (
|
||||
)
|
||||
|
||||
// NewDatabase should use config options to return a connection to the requested database
|
||||
func NewDatabase(storeType, storePath string, encryptionKey []byte) (connection portainer.Connection, err error) {
|
||||
func NewDatabase(storeType, storePath string, encryptionKey []byte, compact bool) (connection portainer.Connection, err error) {
|
||||
if storeType == "boltdb" {
|
||||
return &boltdb.DbConnection{
|
||||
Path: storePath,
|
||||
EncryptionKey: encryptionKey,
|
||||
Compact: compact,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -9,7 +9,8 @@ import (
|
||||
type BaseCRUD[T any, I constraints.Integer] interface {
|
||||
Create(element *T) error
|
||||
Read(ID I) (*T, error)
|
||||
ReadAll() ([]T, error)
|
||||
Exists(ID I) (bool, error)
|
||||
ReadAll(predicates ...func(T) bool) ([]T, error)
|
||||
Update(ID I, element *T) error
|
||||
Delete(ID I) error
|
||||
}
|
||||
@@ -42,12 +43,26 @@ func (service BaseDataService[T, I]) Read(ID I) (*T, error) {
|
||||
})
|
||||
}
|
||||
|
||||
func (service BaseDataService[T, I]) ReadAll() ([]T, error) {
|
||||
func (service BaseDataService[T, I]) Exists(ID I) (bool, error) {
|
||||
var exists bool
|
||||
|
||||
err := service.Connection.ViewTx(func(tx portainer.Transaction) error {
|
||||
var err error
|
||||
exists, err = service.Tx(tx).Exists(ID)
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
return exists, err
|
||||
}
|
||||
|
||||
// ReadAll retrieves all the elements that satisfy all the provided predicates.
|
||||
func (service BaseDataService[T, I]) ReadAll(predicates ...func(T) bool) ([]T, error) {
|
||||
var collection = make([]T, 0)
|
||||
|
||||
return collection, service.Connection.ViewTx(func(tx portainer.Transaction) error {
|
||||
var err error
|
||||
collection, err = service.Tx(tx).ReadAll()
|
||||
collection, err = service.Tx(tx).ReadAll(predicates...)
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
92
api/dataservices/base_test.go
Normal file
92
api/dataservices/base_test.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package dataservices
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/slicesx"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type testObject struct {
|
||||
ID int
|
||||
Value int
|
||||
}
|
||||
|
||||
type mockConnection struct {
|
||||
store map[int]testObject
|
||||
|
||||
portainer.Connection
|
||||
}
|
||||
|
||||
func (m mockConnection) UpdateObject(bucket string, key []byte, value interface{}) error {
|
||||
obj := value.(*testObject)
|
||||
|
||||
m.store[obj.ID] = *obj
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mockConnection) GetAll(bucketName string, obj any, appendFn func(o any) (any, error)) error {
|
||||
for _, v := range m.store {
|
||||
if _, err := appendFn(&v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mockConnection) UpdateTx(fn func(portainer.Transaction) error) error {
|
||||
return fn(m)
|
||||
}
|
||||
|
||||
func (m mockConnection) ViewTx(fn func(portainer.Transaction) error) error {
|
||||
return fn(m)
|
||||
}
|
||||
|
||||
func (m mockConnection) ConvertToKey(v int) []byte {
|
||||
return []byte(strconv.Itoa(v))
|
||||
}
|
||||
|
||||
func TestReadAll(t *testing.T) {
|
||||
service := BaseDataService[testObject, int]{
|
||||
Bucket: "testBucket",
|
||||
Connection: mockConnection{store: make(map[int]testObject)},
|
||||
}
|
||||
|
||||
data := []testObject{
|
||||
{ID: 1, Value: 1},
|
||||
{ID: 2, Value: 2},
|
||||
{ID: 3, Value: 3},
|
||||
{ID: 4, Value: 4},
|
||||
{ID: 5, Value: 5},
|
||||
}
|
||||
|
||||
for _, item := range data {
|
||||
err := service.Update(item.ID, &item)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// ReadAll without predicates
|
||||
result, err := service.ReadAll()
|
||||
require.NoError(t, err)
|
||||
|
||||
expected := append([]testObject{}, data...)
|
||||
|
||||
require.ElementsMatch(t, expected, result)
|
||||
|
||||
// ReadAll with predicates
|
||||
hasLowID := func(obj testObject) bool { return obj.ID < 3 }
|
||||
isEven := func(obj testObject) bool { return obj.Value%2 == 0 }
|
||||
|
||||
result, err = service.ReadAll(hasLowID, isEven)
|
||||
require.NoError(t, err)
|
||||
|
||||
expected = slicesx.Filter(expected, hasLowID)
|
||||
expected = slicesx.Filter(expected, isEven)
|
||||
|
||||
require.ElementsMatch(t, expected, result)
|
||||
}
|
||||
@@ -28,13 +28,38 @@ func (service BaseDataServiceTx[T, I]) Read(ID I) (*T, error) {
|
||||
return &element, nil
|
||||
}
|
||||
|
||||
func (service BaseDataServiceTx[T, I]) ReadAll() ([]T, error) {
|
||||
func (service BaseDataServiceTx[T, I]) Exists(ID I) (bool, error) {
|
||||
identifier := service.Connection.ConvertToKey(int(ID))
|
||||
|
||||
return service.Tx.KeyExists(service.Bucket, identifier)
|
||||
}
|
||||
|
||||
// ReadAll retrieves all the elements that satisfy all the provided predicates.
|
||||
func (service BaseDataServiceTx[T, I]) ReadAll(predicates ...func(T) bool) ([]T, error) {
|
||||
var collection = make([]T, 0)
|
||||
|
||||
if len(predicates) == 0 {
|
||||
return collection, service.Tx.GetAll(
|
||||
service.Bucket,
|
||||
new(T),
|
||||
AppendFn(&collection),
|
||||
)
|
||||
}
|
||||
|
||||
filterFn := func(element T) bool {
|
||||
for _, p := range predicates {
|
||||
if !p(element) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
return collection, service.Tx.GetAll(
|
||||
service.Bucket,
|
||||
new(T),
|
||||
AppendFn(&collection),
|
||||
FilterFn(&collection, filterFn),
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -28,13 +28,12 @@ func NewService(connection portainer.Connection) (*Service, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateCustomTemplate uses the existing id and saves it.
|
||||
// TODO: where does the ID come from, and is it safe?
|
||||
func (service *Service) Create(customTemplate *portainer.CustomTemplate) error {
|
||||
return service.Connection.CreateObjectWithId(BucketName, int(customTemplate.ID), customTemplate)
|
||||
}
|
||||
|
||||
// GetNextIdentifier returns the next identifier for a custom template.
|
||||
func (service *Service) GetNextIdentifier() int {
|
||||
return service.Connection.GetNextIdentifier(BucketName)
|
||||
}
|
||||
|
||||
func (service *Service) Create(customTemplate *portainer.CustomTemplate) error {
|
||||
return service.Connection.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return service.Tx(tx).Create(customTemplate)
|
||||
})
|
||||
}
|
||||
|
||||
19
api/dataservices/customtemplate/customtemplate_test.go
Normal file
19
api/dataservices/customtemplate/customtemplate_test.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package customtemplate_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCustomTemplateCreate(t *testing.T) {
|
||||
_, ds := datastore.MustNewTestStore(t, true, false)
|
||||
require.NotNil(t, ds)
|
||||
|
||||
require.NoError(t, ds.CustomTemplate().Create(&portainer.CustomTemplate{ID: 1}))
|
||||
e, err := ds.CustomTemplate().Read(1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, portainer.CustomTemplateID(1), e.ID)
|
||||
}
|
||||
31
api/dataservices/customtemplate/tx.go
Normal file
31
api/dataservices/customtemplate/tx.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package customtemplate
|
||||
|
||||
import (
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
)
|
||||
|
||||
// Service represents a service for managing custom template data.
|
||||
type ServiceTx struct {
|
||||
dataservices.BaseDataServiceTx[portainer.CustomTemplate, portainer.CustomTemplateID]
|
||||
}
|
||||
|
||||
func (service *Service) Tx(tx portainer.Transaction) ServiceTx {
|
||||
return ServiceTx{
|
||||
BaseDataServiceTx: dataservices.BaseDataServiceTx[portainer.CustomTemplate, portainer.CustomTemplateID]{
|
||||
Bucket: BucketName,
|
||||
Connection: service.Connection,
|
||||
Tx: tx,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (service ServiceTx) GetNextIdentifier() int {
|
||||
return service.Tx.GetNextIdentifier(BucketName)
|
||||
}
|
||||
|
||||
// CreateCustomTemplate uses the existing id and saves it.
|
||||
// TODO: where does the ID come from, and is it safe?
|
||||
func (service ServiceTx) Create(customTemplate *portainer.CustomTemplate) error {
|
||||
return service.Tx.CreateObjectWithId(BucketName, int(customTemplate.ID), customTemplate)
|
||||
}
|
||||
28
api/dataservices/customtemplate/tx_test.go
Normal file
28
api/dataservices/customtemplate/tx_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package customtemplate_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCustomTemplateCreateTx(t *testing.T) {
|
||||
_, ds := datastore.MustNewTestStore(t, true, false)
|
||||
require.NotNil(t, ds)
|
||||
|
||||
require.NoError(t, ds.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||
return tx.CustomTemplate().Create(&portainer.CustomTemplate{ID: 1})
|
||||
}))
|
||||
|
||||
var template *portainer.CustomTemplate
|
||||
require.NoError(t, ds.ViewTx(func(tx dataservices.DataStoreTx) error {
|
||||
var err error
|
||||
template, err = tx.CustomTemplate().Read(1)
|
||||
return err
|
||||
}))
|
||||
|
||||
require.Equal(t, portainer.CustomTemplateID(1), template.ID)
|
||||
}
|
||||
@@ -17,11 +17,29 @@ func (service ServiceTx) UpdateEdgeGroupFunc(ID portainer.EdgeGroupID, updateFun
|
||||
}
|
||||
|
||||
func (service ServiceTx) Create(group *portainer.EdgeGroup) error {
|
||||
return service.Tx.CreateObject(
|
||||
es := group.Endpoints
|
||||
group.Endpoints = nil // Clear deprecated field
|
||||
|
||||
err := service.Tx.CreateObject(
|
||||
BucketName,
|
||||
func(id uint64) (int, any) {
|
||||
group.ID = portainer.EdgeGroupID(id)
|
||||
return int(group.ID), group
|
||||
},
|
||||
)
|
||||
|
||||
group.Endpoints = es // Restore endpoints after create
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (service ServiceTx) Update(ID portainer.EdgeGroupID, group *portainer.EdgeGroup) error {
|
||||
es := group.Endpoints
|
||||
group.Endpoints = nil // Clear deprecated field
|
||||
|
||||
err := service.BaseDataServiceTx.Update(ID, group)
|
||||
|
||||
group.Endpoints = es // Restore endpoints after update
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
50
api/dataservices/edgestack/edgestack_test.go
Normal file
50
api/dataservices/edgestack/edgestack_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package edgestack
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/database/boltdb"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUpdate(t *testing.T) {
|
||||
var conn portainer.Connection = &boltdb.DbConnection{Path: t.TempDir()}
|
||||
err := conn.Open()
|
||||
require.NoError(t, err)
|
||||
|
||||
defer conn.Close()
|
||||
|
||||
service, err := NewService(conn, func(portainer.Transaction, portainer.EdgeStackID) {})
|
||||
require.NoError(t, err)
|
||||
|
||||
const edgeStackID = 1
|
||||
edgeStack := &portainer.EdgeStack{
|
||||
ID: edgeStackID,
|
||||
Name: "Test Stack",
|
||||
}
|
||||
|
||||
err = service.Create(edgeStackID, edgeStack)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = service.UpdateEdgeStackFunc(edgeStackID, func(edgeStack *portainer.EdgeStack) {
|
||||
edgeStack.Name = "Updated Stack"
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedStack, err := service.EdgeStack(edgeStackID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Updated Stack", updatedStack.Name)
|
||||
|
||||
err = conn.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return service.UpdateEdgeStackFuncTx(tx, edgeStackID, func(edgeStack *portainer.EdgeStack) {
|
||||
edgeStack.Name = "Updated Stack Again"
|
||||
})
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedStack, err = service.EdgeStack(edgeStackID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Updated Stack Again", updatedStack.Name)
|
||||
}
|
||||
89
api/dataservices/edgestackstatus/edgestackstatus.go
Normal file
89
api/dataservices/edgestackstatus/edgestackstatus.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package edgestackstatus
|
||||
|
||||
import (
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
)
|
||||
|
||||
var _ dataservices.EdgeStackStatusService = &Service{}
|
||||
|
||||
const BucketName = "edge_stack_status"
|
||||
|
||||
type Service struct {
|
||||
conn portainer.Connection
|
||||
}
|
||||
|
||||
func (service *Service) BucketName() string {
|
||||
return BucketName
|
||||
}
|
||||
|
||||
func NewService(connection portainer.Connection) (*Service, error) {
|
||||
if err := connection.SetServiceName(BucketName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Service{conn: connection}, nil
|
||||
}
|
||||
|
||||
func (s *Service) Tx(tx portainer.Transaction) ServiceTx {
|
||||
return ServiceTx{
|
||||
service: s,
|
||||
tx: tx,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Create(edgeStackID portainer.EdgeStackID, endpointID portainer.EndpointID, status *portainer.EdgeStackStatusForEnv) error {
|
||||
return s.conn.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return s.Tx(tx).Create(edgeStackID, endpointID, status)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) Read(edgeStackID portainer.EdgeStackID, endpointID portainer.EndpointID) (*portainer.EdgeStackStatusForEnv, error) {
|
||||
var element *portainer.EdgeStackStatusForEnv
|
||||
|
||||
return element, s.conn.ViewTx(func(tx portainer.Transaction) error {
|
||||
var err error
|
||||
element, err = s.Tx(tx).Read(edgeStackID, endpointID)
|
||||
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) ReadAll(edgeStackID portainer.EdgeStackID) ([]portainer.EdgeStackStatusForEnv, error) {
|
||||
var collection = make([]portainer.EdgeStackStatusForEnv, 0)
|
||||
|
||||
return collection, s.conn.ViewTx(func(tx portainer.Transaction) error {
|
||||
var err error
|
||||
collection, err = s.Tx(tx).ReadAll(edgeStackID)
|
||||
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) Update(edgeStackID portainer.EdgeStackID, endpointID portainer.EndpointID, status *portainer.EdgeStackStatusForEnv) error {
|
||||
return s.conn.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return s.Tx(tx).Update(edgeStackID, endpointID, status)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) Delete(edgeStackID portainer.EdgeStackID, endpointID portainer.EndpointID) error {
|
||||
return s.conn.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return s.Tx(tx).Delete(edgeStackID, endpointID)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) DeleteAll(edgeStackID portainer.EdgeStackID) error {
|
||||
return s.conn.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return s.Tx(tx).DeleteAll(edgeStackID)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) Clear(edgeStackID portainer.EdgeStackID, relatedEnvironmentsIDs []portainer.EndpointID) error {
|
||||
return s.conn.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return s.Tx(tx).Clear(edgeStackID, relatedEnvironmentsIDs)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) key(edgeStackID portainer.EdgeStackID, endpointID portainer.EndpointID) []byte {
|
||||
return append(s.conn.ConvertToKey(int(edgeStackID)), s.conn.ConvertToKey(int(endpointID))...)
|
||||
}
|
||||
95
api/dataservices/edgestackstatus/tx.go
Normal file
95
api/dataservices/edgestackstatus/tx.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package edgestackstatus
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
)
|
||||
|
||||
var _ dataservices.EdgeStackStatusService = &Service{}
|
||||
|
||||
type ServiceTx struct {
|
||||
service *Service
|
||||
tx portainer.Transaction
|
||||
}
|
||||
|
||||
func (service ServiceTx) Create(edgeStackID portainer.EdgeStackID, endpointID portainer.EndpointID, status *portainer.EdgeStackStatusForEnv) error {
|
||||
identifier := service.service.key(edgeStackID, endpointID)
|
||||
return service.tx.CreateObjectWithStringId(BucketName, identifier, status)
|
||||
}
|
||||
|
||||
func (s ServiceTx) Read(edgeStackID portainer.EdgeStackID, endpointID portainer.EndpointID) (*portainer.EdgeStackStatusForEnv, error) {
|
||||
var status portainer.EdgeStackStatusForEnv
|
||||
identifier := s.service.key(edgeStackID, endpointID)
|
||||
|
||||
if err := s.tx.GetObject(BucketName, identifier, &status); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &status, nil
|
||||
}
|
||||
|
||||
func (s ServiceTx) ReadAll(edgeStackID portainer.EdgeStackID) ([]portainer.EdgeStackStatusForEnv, error) {
|
||||
keyPrefix := s.service.conn.ConvertToKey(int(edgeStackID))
|
||||
|
||||
statuses := make([]portainer.EdgeStackStatusForEnv, 0)
|
||||
|
||||
if err := s.tx.GetAllWithKeyPrefix(BucketName, keyPrefix, &portainer.EdgeStackStatusForEnv{}, dataservices.AppendFn(&statuses)); err != nil {
|
||||
return nil, fmt.Errorf("unable to retrieve EdgeStackStatus for EdgeStack %d: %w", edgeStackID, err)
|
||||
}
|
||||
|
||||
return statuses, nil
|
||||
}
|
||||
|
||||
func (s ServiceTx) Update(edgeStackID portainer.EdgeStackID, endpointID portainer.EndpointID, status *portainer.EdgeStackStatusForEnv) error {
|
||||
identifier := s.service.key(edgeStackID, endpointID)
|
||||
return s.tx.UpdateObject(BucketName, identifier, status)
|
||||
}
|
||||
|
||||
func (s ServiceTx) Delete(edgeStackID portainer.EdgeStackID, endpointID portainer.EndpointID) error {
|
||||
identifier := s.service.key(edgeStackID, endpointID)
|
||||
return s.tx.DeleteObject(BucketName, identifier)
|
||||
}
|
||||
|
||||
func (s ServiceTx) DeleteAll(edgeStackID portainer.EdgeStackID) error {
|
||||
keyPrefix := s.service.conn.ConvertToKey(int(edgeStackID))
|
||||
|
||||
statuses := make([]portainer.EdgeStackStatusForEnv, 0)
|
||||
|
||||
if err := s.tx.GetAllWithKeyPrefix(BucketName, keyPrefix, &portainer.EdgeStackStatusForEnv{}, dataservices.AppendFn(&statuses)); err != nil {
|
||||
return fmt.Errorf("unable to retrieve EdgeStackStatus for EdgeStack %d: %w", edgeStackID, err)
|
||||
}
|
||||
|
||||
for _, status := range statuses {
|
||||
if err := s.tx.DeleteObject(BucketName, s.service.key(edgeStackID, status.EndpointID)); err != nil {
|
||||
return fmt.Errorf("unable to delete EdgeStackStatus for EdgeStack %d and Endpoint %d: %w", edgeStackID, status.EndpointID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s ServiceTx) Clear(edgeStackID portainer.EdgeStackID, relatedEnvironmentsIDs []portainer.EndpointID) error {
|
||||
for _, envID := range relatedEnvironmentsIDs {
|
||||
existingStatus, err := s.Read(edgeStackID, envID)
|
||||
if err != nil && !dataservices.IsErrObjectNotFound(err) {
|
||||
return fmt.Errorf("unable to retrieve status for environment %d: %w", envID, err)
|
||||
}
|
||||
|
||||
var deploymentInfo portainer.StackDeploymentInfo
|
||||
if existingStatus != nil {
|
||||
deploymentInfo = existingStatus.DeploymentInfo
|
||||
}
|
||||
|
||||
if err := s.Update(edgeStackID, envID, &portainer.EdgeStackStatusForEnv{
|
||||
EndpointID: envID,
|
||||
Status: []portainer.EdgeStackDeploymentStatus{},
|
||||
DeploymentInfo: deploymentInfo,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -6,8 +6,6 @@ import (
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/internal/edge/cache"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// BucketName represents the name of the bucket where this service stores data.
|
||||
@@ -16,7 +14,6 @@ const BucketName = "endpoint_relations"
|
||||
// Service represents a service for managing environment(endpoint) relation data.
|
||||
type Service struct {
|
||||
connection portainer.Connection
|
||||
updateStackFn func(ID portainer.EdgeStackID, updateFunc func(edgeStack *portainer.EdgeStack)) error
|
||||
updateStackFnTx func(tx portainer.Transaction, ID portainer.EdgeStackID, updateFunc func(edgeStack *portainer.EdgeStack)) error
|
||||
endpointRelationsCache []portainer.EndpointRelation
|
||||
mu sync.Mutex
|
||||
@@ -29,10 +26,8 @@ func (service *Service) BucketName() string {
|
||||
}
|
||||
|
||||
func (service *Service) RegisterUpdateStackFunction(
|
||||
updateFunc func(portainer.EdgeStackID, func(*portainer.EdgeStack)) error,
|
||||
updateFuncTx func(portainer.Transaction, portainer.EdgeStackID, func(*portainer.EdgeStack)) error,
|
||||
) {
|
||||
service.updateStackFn = updateFunc
|
||||
service.updateStackFnTx = updateFuncTx
|
||||
}
|
||||
|
||||
@@ -91,106 +86,26 @@ func (service *Service) Create(endpointRelation *portainer.EndpointRelation) err
|
||||
|
||||
// UpdateEndpointRelation updates an Environment(Endpoint) relation object
|
||||
func (service *Service) UpdateEndpointRelation(endpointID portainer.EndpointID, endpointRelation *portainer.EndpointRelation) error {
|
||||
previousRelationState, _ := service.EndpointRelation(endpointID)
|
||||
|
||||
identifier := service.connection.ConvertToKey(int(endpointID))
|
||||
err := service.connection.UpdateObject(BucketName, identifier, endpointRelation)
|
||||
cache.Del(endpointID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updatedRelationState, _ := service.EndpointRelation(endpointID)
|
||||
|
||||
service.mu.Lock()
|
||||
service.endpointRelationsCache = nil
|
||||
service.mu.Unlock()
|
||||
|
||||
service.updateEdgeStacksAfterRelationChange(previousRelationState, updatedRelationState)
|
||||
|
||||
return nil
|
||||
return service.connection.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return service.Tx(tx).UpdateEndpointRelation(endpointID, endpointRelation)
|
||||
})
|
||||
}
|
||||
|
||||
func (service *Service) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error {
|
||||
return service.connection.ViewTx(func(tx portainer.Transaction) error {
|
||||
return service.Tx(tx).AddEndpointRelationsForEdgeStack(endpointIDs, edgeStackID)
|
||||
func (service *Service) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStack *portainer.EdgeStack) error {
|
||||
return service.connection.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return service.Tx(tx).AddEndpointRelationsForEdgeStack(endpointIDs, edgeStack)
|
||||
})
|
||||
}
|
||||
|
||||
func (service *Service) RemoveEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error {
|
||||
return service.connection.ViewTx(func(tx portainer.Transaction) error {
|
||||
return service.connection.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return service.Tx(tx).RemoveEndpointRelationsForEdgeStack(endpointIDs, edgeStackID)
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteEndpointRelation deletes an Environment(Endpoint) relation object
|
||||
func (service *Service) DeleteEndpointRelation(endpointID portainer.EndpointID) error {
|
||||
deletedRelation, _ := service.EndpointRelation(endpointID)
|
||||
|
||||
identifier := service.connection.ConvertToKey(int(endpointID))
|
||||
err := service.connection.DeleteObject(BucketName, identifier)
|
||||
cache.Del(endpointID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
service.mu.Lock()
|
||||
service.endpointRelationsCache = nil
|
||||
service.mu.Unlock()
|
||||
|
||||
service.updateEdgeStacksAfterRelationChange(deletedRelation, nil)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (service *Service) updateEdgeStacksAfterRelationChange(previousRelationState *portainer.EndpointRelation, updatedRelationState *portainer.EndpointRelation) {
|
||||
relations, _ := service.EndpointRelations()
|
||||
|
||||
stacksToUpdate := map[portainer.EdgeStackID]bool{}
|
||||
|
||||
if previousRelationState != nil {
|
||||
for stackId, enabled := range previousRelationState.EdgeStacks {
|
||||
// flag stack for update if stack is not in the updated relation state
|
||||
// = stack has been removed for this relation
|
||||
// or this relation has been deleted
|
||||
if enabled && (updatedRelationState == nil || !updatedRelationState.EdgeStacks[stackId]) {
|
||||
stacksToUpdate[stackId] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if updatedRelationState != nil {
|
||||
for stackId, enabled := range updatedRelationState.EdgeStacks {
|
||||
// flag stack for update if stack is not in the previous relation state
|
||||
// = stack has been added for this relation
|
||||
if enabled && (previousRelationState == nil || !previousRelationState.EdgeStacks[stackId]) {
|
||||
stacksToUpdate[stackId] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// for each stack referenced by the updated relation
|
||||
// list how many time this stack is referenced in all relations
|
||||
// in order to update the stack deployments count
|
||||
for refStackId, refStackEnabled := range stacksToUpdate {
|
||||
if !refStackEnabled {
|
||||
continue
|
||||
}
|
||||
|
||||
numDeployments := 0
|
||||
|
||||
for _, r := range relations {
|
||||
for sId, enabled := range r.EdgeStacks {
|
||||
if enabled && sId == refStackId {
|
||||
numDeployments += 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := service.updateStackFn(refStackId, func(edgeStack *portainer.EdgeStack) {
|
||||
edgeStack.NumDeployments = numDeployments
|
||||
}); err != nil {
|
||||
log.Error().Err(err).Msg("could not update the number of deployments")
|
||||
}
|
||||
}
|
||||
return service.connection.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return service.Tx(tx).DeleteEndpointRelation(endpointID)
|
||||
})
|
||||
}
|
||||
|
||||
140
api/dataservices/endpointrelation/endpointrelation_test.go
Normal file
140
api/dataservices/endpointrelation/endpointrelation_test.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package endpointrelation
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/database/boltdb"
|
||||
"github.com/portainer/portainer/api/dataservices/edgestack"
|
||||
"github.com/portainer/portainer/api/internal/edge/cache"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUpdateRelation(t *testing.T) {
|
||||
const endpointID = 1
|
||||
const edgeStackID1 = 1
|
||||
const edgeStackID2 = 2
|
||||
|
||||
var conn portainer.Connection = &boltdb.DbConnection{Path: t.TempDir()}
|
||||
err := conn.Open()
|
||||
require.NoError(t, err)
|
||||
|
||||
defer conn.Close()
|
||||
|
||||
service, err := NewService(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
updateStackFnTxCalled := false
|
||||
|
||||
edgeStacks := make(map[portainer.EdgeStackID]portainer.EdgeStack)
|
||||
edgeStacks[edgeStackID1] = portainer.EdgeStack{ID: edgeStackID1}
|
||||
edgeStacks[edgeStackID2] = portainer.EdgeStack{ID: edgeStackID2}
|
||||
|
||||
service.RegisterUpdateStackFunction(func(tx portainer.Transaction, ID portainer.EdgeStackID, updateFunc func(edgeStack *portainer.EdgeStack)) error {
|
||||
updateStackFnTxCalled = true
|
||||
|
||||
s, ok := edgeStacks[ID]
|
||||
require.True(t, ok)
|
||||
|
||||
updateFunc(&s)
|
||||
edgeStacks[ID] = s
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
// Nil relation
|
||||
|
||||
cache.Set(endpointID, []byte("value"))
|
||||
|
||||
err = service.UpdateEndpointRelation(endpointID, nil)
|
||||
_, cacheKeyExists := cache.Get(endpointID)
|
||||
require.NoError(t, err)
|
||||
require.False(t, updateStackFnTxCalled)
|
||||
require.False(t, cacheKeyExists)
|
||||
|
||||
// Add a relation to two edge stacks
|
||||
|
||||
cache.Set(endpointID, []byte("value"))
|
||||
|
||||
err = service.UpdateEndpointRelation(endpointID, &portainer.EndpointRelation{
|
||||
EndpointID: endpointID,
|
||||
EdgeStacks: map[portainer.EdgeStackID]bool{
|
||||
edgeStackID1: true,
|
||||
edgeStackID2: true,
|
||||
},
|
||||
})
|
||||
_, cacheKeyExists = cache.Get(endpointID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, updateStackFnTxCalled)
|
||||
require.False(t, cacheKeyExists)
|
||||
require.Equal(t, 1, edgeStacks[edgeStackID1].NumDeployments)
|
||||
require.Equal(t, 1, edgeStacks[edgeStackID2].NumDeployments)
|
||||
|
||||
// Remove a relation to one edge stack
|
||||
|
||||
updateStackFnTxCalled = false
|
||||
cache.Set(endpointID, []byte("value"))
|
||||
|
||||
err = service.UpdateEndpointRelation(endpointID, &portainer.EndpointRelation{
|
||||
EndpointID: endpointID,
|
||||
EdgeStacks: map[portainer.EdgeStackID]bool{
|
||||
2: true,
|
||||
},
|
||||
})
|
||||
_, cacheKeyExists = cache.Get(endpointID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, updateStackFnTxCalled)
|
||||
require.False(t, cacheKeyExists)
|
||||
require.Equal(t, 0, edgeStacks[edgeStackID1].NumDeployments)
|
||||
require.Equal(t, 1, edgeStacks[edgeStackID2].NumDeployments)
|
||||
|
||||
// Delete the relation
|
||||
|
||||
updateStackFnTxCalled = false
|
||||
cache.Set(endpointID, []byte("value"))
|
||||
|
||||
err = service.DeleteEndpointRelation(endpointID)
|
||||
|
||||
_, cacheKeyExists = cache.Get(endpointID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, updateStackFnTxCalled)
|
||||
require.False(t, cacheKeyExists)
|
||||
require.Equal(t, 0, edgeStacks[edgeStackID1].NumDeployments)
|
||||
require.Equal(t, 0, edgeStacks[edgeStackID2].NumDeployments)
|
||||
}
|
||||
|
||||
func TestAddEndpointRelationsForEdgeStack(t *testing.T) {
|
||||
var conn portainer.Connection = &boltdb.DbConnection{Path: t.TempDir()}
|
||||
err := conn.Open()
|
||||
require.NoError(t, err)
|
||||
|
||||
defer conn.Close()
|
||||
|
||||
service, err := NewService(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
edgeStackService, err := edgestack.NewService(conn, func(t portainer.Transaction, esi portainer.EdgeStackID) {})
|
||||
require.NoError(t, err)
|
||||
|
||||
service.RegisterUpdateStackFunction(edgeStackService.UpdateEdgeStackFuncTx)
|
||||
require.NoError(t, edgeStackService.Create(1, &portainer.EdgeStack{}))
|
||||
require.NoError(t, service.Create(&portainer.EndpointRelation{EndpointID: 1, EdgeStacks: map[portainer.EdgeStackID]bool{}}))
|
||||
require.NoError(t, service.AddEndpointRelationsForEdgeStack([]portainer.EndpointID{1}, &portainer.EdgeStack{ID: 1}))
|
||||
}
|
||||
|
||||
func TestEndpointRelations(t *testing.T) {
|
||||
var conn portainer.Connection = &boltdb.DbConnection{Path: t.TempDir()}
|
||||
err := conn.Open()
|
||||
require.NoError(t, err)
|
||||
|
||||
defer conn.Close()
|
||||
|
||||
service, err := NewService(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, service.Create(&portainer.EndpointRelation{EndpointID: 1}))
|
||||
rels, err := service.EndpointRelations()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(rels))
|
||||
}
|
||||
@@ -76,14 +76,14 @@ func (service ServiceTx) UpdateEndpointRelation(endpointID portainer.EndpointID,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (service ServiceTx) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error {
|
||||
func (service ServiceTx) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStack *portainer.EdgeStack) error {
|
||||
for _, endpointID := range endpointIDs {
|
||||
rel, err := service.EndpointRelation(endpointID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rel.EdgeStacks[edgeStackID] = true
|
||||
rel.EdgeStacks[edgeStack.ID] = true
|
||||
|
||||
identifier := service.service.connection.ConvertToKey(int(endpointID))
|
||||
err = service.tx.UpdateObject(BucketName, identifier, rel)
|
||||
@@ -93,8 +93,16 @@ func (service ServiceTx) AddEndpointRelationsForEdgeStack(endpointIDs []portaine
|
||||
}
|
||||
}
|
||||
|
||||
if err := service.service.updateStackFnTx(service.tx, edgeStackID, func(edgeStack *portainer.EdgeStack) {
|
||||
edgeStack.NumDeployments += len(endpointIDs)
|
||||
service.service.mu.Lock()
|
||||
service.service.endpointRelationsCache = nil
|
||||
service.service.mu.Unlock()
|
||||
|
||||
if err := service.service.updateStackFnTx(service.tx, edgeStack.ID, func(es *portainer.EdgeStack) {
|
||||
es.NumDeployments += len(endpointIDs)
|
||||
|
||||
// sync changes in `edgeStack` in case it is re-persisted after `AddEndpointRelationsForEdgeStack` call
|
||||
// to avoid overriding with the previous values
|
||||
edgeStack.NumDeployments = es.NumDeployments
|
||||
}); err != nil {
|
||||
log.Error().Err(err).Msg("could not update the number of deployments")
|
||||
}
|
||||
@@ -119,6 +127,10 @@ func (service ServiceTx) RemoveEndpointRelationsForEdgeStack(endpointIDs []porta
|
||||
}
|
||||
}
|
||||
|
||||
service.service.mu.Lock()
|
||||
service.service.endpointRelationsCache = nil
|
||||
service.service.mu.Unlock()
|
||||
|
||||
if err := service.service.updateStackFnTx(service.tx, edgeStackID, func(edgeStack *portainer.EdgeStack) {
|
||||
edgeStack.NumDeployments -= len(endpointIDs)
|
||||
}); err != nil {
|
||||
@@ -178,53 +190,49 @@ func (service ServiceTx) cachedEndpointRelations() ([]portainer.EndpointRelation
|
||||
}
|
||||
|
||||
func (service ServiceTx) updateEdgeStacksAfterRelationChange(previousRelationState *portainer.EndpointRelation, updatedRelationState *portainer.EndpointRelation) {
|
||||
relations, _ := service.EndpointRelations()
|
||||
|
||||
stacksToUpdate := map[portainer.EdgeStackID]bool{}
|
||||
|
||||
if previousRelationState != nil {
|
||||
for stackId, enabled := range previousRelationState.EdgeStacks {
|
||||
// flag stack for update if stack is not in the updated relation state
|
||||
// = stack has been removed for this relation
|
||||
// or this relation has been deleted
|
||||
if enabled && (updatedRelationState == nil || !updatedRelationState.EdgeStacks[stackId]) {
|
||||
stacksToUpdate[stackId] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := service.service.updateStackFnTx(service.tx, stackId, func(edgeStack *portainer.EdgeStack) {
|
||||
// Sanity check
|
||||
if edgeStack.NumDeployments <= 0 {
|
||||
log.Error().
|
||||
Int("edgestack_id", int(edgeStack.ID)).
|
||||
Int("endpoint_id", int(previousRelationState.EndpointID)).
|
||||
Int("num_deployments", edgeStack.NumDeployments).
|
||||
Msg("cannot decrement the number of deployments for an edge stack with zero deployments")
|
||||
|
||||
if updatedRelationState != nil {
|
||||
for stackId, enabled := range updatedRelationState.EdgeStacks {
|
||||
// flag stack for update if stack is not in the previous relation state
|
||||
// = stack has been added for this relation
|
||||
if enabled && (previousRelationState == nil || !previousRelationState.EdgeStacks[stackId]) {
|
||||
stacksToUpdate[stackId] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// for each stack referenced by the updated relation
|
||||
// list how many time this stack is referenced in all relations
|
||||
// in order to update the stack deployments count
|
||||
for refStackId, refStackEnabled := range stacksToUpdate {
|
||||
if !refStackEnabled {
|
||||
continue
|
||||
}
|
||||
|
||||
numDeployments := 0
|
||||
|
||||
for _, r := range relations {
|
||||
for sId, enabled := range r.EdgeStacks {
|
||||
if enabled && sId == refStackId {
|
||||
numDeployments += 1
|
||||
edgeStack.NumDeployments--
|
||||
}); err != nil {
|
||||
log.Error().Err(err).Msg("could not update the number of deployments")
|
||||
}
|
||||
|
||||
cache.Del(previousRelationState.EndpointID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := service.service.updateStackFnTx(service.tx, refStackId, func(edgeStack *portainer.EdgeStack) {
|
||||
edgeStack.NumDeployments = numDeployments
|
||||
}); err != nil {
|
||||
log.Error().Err(err).Msg("could not update the number of deployments")
|
||||
if updatedRelationState == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for stackId, enabled := range updatedRelationState.EdgeStacks {
|
||||
// flag stack for update if stack is not in the previous relation state
|
||||
// = stack has been added for this relation
|
||||
if enabled && (previousRelationState == nil || !previousRelationState.EdgeStacks[stackId]) {
|
||||
if err := service.service.updateStackFnTx(service.tx, stackId, func(edgeStack *portainer.EdgeStack) {
|
||||
edgeStack.NumDeployments++
|
||||
}); err != nil {
|
||||
log.Error().Err(err).Msg("could not update the number of deployments")
|
||||
}
|
||||
|
||||
cache.Del(updatedRelationState.EndpointID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ type (
|
||||
EdgeGroup() EdgeGroupService
|
||||
EdgeJob() EdgeJobService
|
||||
EdgeStack() EdgeStackService
|
||||
EdgeStackStatus() EdgeStackStatusService
|
||||
Endpoint() EndpointService
|
||||
EndpointGroup() EndpointGroupService
|
||||
EndpointRelation() EndpointRelationService
|
||||
@@ -39,8 +40,8 @@ type (
|
||||
Open() (newStore bool, err error)
|
||||
Init() error
|
||||
Close() error
|
||||
UpdateTx(func(DataStoreTx) error) error
|
||||
ViewTx(func(DataStoreTx) error) error
|
||||
UpdateTx(func(tx DataStoreTx) error) error
|
||||
ViewTx(func(tx DataStoreTx) error) error
|
||||
MigrateData() error
|
||||
Rollback(force bool) error
|
||||
CheckCurrentEdition() error
|
||||
@@ -89,6 +90,16 @@ type (
|
||||
BucketName() string
|
||||
}
|
||||
|
||||
EdgeStackStatusService interface {
|
||||
Create(edgeStackID portainer.EdgeStackID, endpointID portainer.EndpointID, status *portainer.EdgeStackStatusForEnv) error
|
||||
Read(edgeStackID portainer.EdgeStackID, endpointID portainer.EndpointID) (*portainer.EdgeStackStatusForEnv, error)
|
||||
ReadAll(edgeStackID portainer.EdgeStackID) ([]portainer.EdgeStackStatusForEnv, error)
|
||||
Update(edgeStackID portainer.EdgeStackID, endpointID portainer.EndpointID, status *portainer.EdgeStackStatusForEnv) error
|
||||
Delete(edgeStackID portainer.EdgeStackID, endpointID portainer.EndpointID) error
|
||||
DeleteAll(edgeStackID portainer.EdgeStackID) error
|
||||
Clear(edgeStackID portainer.EdgeStackID, relatedEnvironmentsIDs []portainer.EndpointID) error
|
||||
}
|
||||
|
||||
// EndpointService represents a service for managing environment(endpoint) data
|
||||
EndpointService interface {
|
||||
Endpoint(ID portainer.EndpointID) (*portainer.Endpoint, error)
|
||||
@@ -115,7 +126,7 @@ type (
|
||||
EndpointRelation(EndpointID portainer.EndpointID) (*portainer.EndpointRelation, error)
|
||||
Create(endpointRelation *portainer.EndpointRelation) error
|
||||
UpdateEndpointRelation(EndpointID portainer.EndpointID, endpointRelation *portainer.EndpointRelation) error
|
||||
AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error
|
||||
AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStack *portainer.EdgeStack) error
|
||||
RemoveEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error
|
||||
DeleteEndpointRelation(EndpointID portainer.EndpointID) error
|
||||
BucketName() string
|
||||
@@ -159,6 +170,7 @@ type (
|
||||
|
||||
SnapshotService interface {
|
||||
BaseCRUD[portainer.Snapshot, portainer.EndpointID]
|
||||
ReadWithoutSnapshotRaw(ID portainer.EndpointID) (*portainer.Snapshot, error)
|
||||
}
|
||||
|
||||
// SSLSettingsService represents a service for managing application settings
|
||||
|
||||
@@ -1,26 +1,16 @@
|
||||
package pendingactions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
const (
|
||||
BucketName = "pending_actions"
|
||||
)
|
||||
const BucketName = "pending_actions"
|
||||
|
||||
type Service struct {
|
||||
dataservices.BaseDataService[portainer.PendingAction, portainer.PendingActionID]
|
||||
}
|
||||
|
||||
type ServiceTx struct {
|
||||
dataservices.BaseDataServiceTx[portainer.PendingAction, portainer.PendingActionID]
|
||||
}
|
||||
|
||||
func NewService(connection portainer.Connection) (*Service, error) {
|
||||
err := connection.SetServiceName(BucketName)
|
||||
if err != nil {
|
||||
@@ -35,6 +25,11 @@ func NewService(connection portainer.Connection) (*Service, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetNextIdentifier returns the next identifier for a custom template.
|
||||
func (service *Service) GetNextIdentifier() int {
|
||||
return service.Connection.GetNextIdentifier(BucketName)
|
||||
}
|
||||
|
||||
func (s Service) Create(config *portainer.PendingAction) error {
|
||||
return s.Connection.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return s.Tx(tx).Create(config)
|
||||
@@ -62,44 +57,3 @@ func (service *Service) Tx(tx portainer.Transaction) ServiceTx {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s ServiceTx) Create(config *portainer.PendingAction) error {
|
||||
return s.Tx.CreateObject(BucketName, func(id uint64) (int, any) {
|
||||
config.ID = portainer.PendingActionID(id)
|
||||
config.CreatedAt = time.Now().Unix()
|
||||
|
||||
return int(config.ID), config
|
||||
})
|
||||
}
|
||||
|
||||
func (s ServiceTx) Update(ID portainer.PendingActionID, config *portainer.PendingAction) error {
|
||||
return s.BaseDataServiceTx.Update(ID, config)
|
||||
}
|
||||
|
||||
func (s ServiceTx) DeleteByEndpointID(ID portainer.EndpointID) error {
|
||||
log.Debug().Int("endpointId", int(ID)).Msg("deleting pending actions for endpoint")
|
||||
pendingActions, err := s.BaseDataServiceTx.ReadAll()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to retrieve pending-actions for endpoint (%d): %w", ID, err)
|
||||
}
|
||||
|
||||
for _, pendingAction := range pendingActions {
|
||||
if pendingAction.EndpointID == ID {
|
||||
err := s.BaseDataServiceTx.Delete(pendingAction.ID)
|
||||
if err != nil {
|
||||
log.Debug().Int("endpointId", int(ID)).Msgf("failed to delete pending action: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetNextIdentifier returns the next identifier for a custom template.
|
||||
func (service ServiceTx) GetNextIdentifier() int {
|
||||
return service.Tx.GetNextIdentifier(BucketName)
|
||||
}
|
||||
|
||||
// GetNextIdentifier returns the next identifier for a custom template.
|
||||
func (service *Service) GetNextIdentifier() int {
|
||||
return service.Connection.GetNextIdentifier(BucketName)
|
||||
}
|
||||
|
||||
32
api/dataservices/pendingactions/pendingactions_test.go
Normal file
32
api/dataservices/pendingactions/pendingactions_test.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package pendingactions_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDeleteByEndpoint(t *testing.T) {
|
||||
_, store := datastore.MustNewTestStore(t, false, false)
|
||||
|
||||
// Create Endpoint 1
|
||||
err := store.PendingActions().Create(&portainer.PendingAction{EndpointID: 1})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create Endpoint 2
|
||||
err = store.PendingActions().Create(&portainer.PendingAction{EndpointID: 2})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete Endpoint 1
|
||||
err = store.PendingActions().DeleteByEndpointID(1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that only Endpoint 2 remains
|
||||
pendingActions, err := store.PendingActions().ReadAll()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, pendingActions, 1)
|
||||
require.Equal(t, portainer.EndpointID(2), pendingActions[0].EndpointID)
|
||||
}
|
||||
49
api/dataservices/pendingactions/tx.go
Normal file
49
api/dataservices/pendingactions/tx.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package pendingactions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type ServiceTx struct {
|
||||
dataservices.BaseDataServiceTx[portainer.PendingAction, portainer.PendingActionID]
|
||||
}
|
||||
|
||||
func (s ServiceTx) Create(config *portainer.PendingAction) error {
|
||||
return s.Tx.CreateObject(BucketName, func(id uint64) (int, any) {
|
||||
config.ID = portainer.PendingActionID(id)
|
||||
config.CreatedAt = time.Now().Unix()
|
||||
|
||||
return int(config.ID), config
|
||||
})
|
||||
}
|
||||
|
||||
func (s ServiceTx) Update(ID portainer.PendingActionID, config *portainer.PendingAction) error {
|
||||
return s.BaseDataServiceTx.Update(ID, config)
|
||||
}
|
||||
|
||||
func (s ServiceTx) DeleteByEndpointID(ID portainer.EndpointID) error {
|
||||
log.Debug().Int("endpointId", int(ID)).Msg("deleting pending actions for endpoint")
|
||||
pendingActions, err := s.ReadAll()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to retrieve pending-actions for endpoint (%d): %w", ID, err)
|
||||
}
|
||||
|
||||
for _, pendingAction := range pendingActions {
|
||||
if pendingAction.EndpointID == ID {
|
||||
if err := s.Delete(pendingAction.ID); err != nil {
|
||||
log.Debug().Int("endpointId", int(ID)).Msgf("failed to delete pending action: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetNextIdentifier returns the next identifier for a custom template.
|
||||
func (service ServiceTx) GetNextIdentifier() int {
|
||||
return service.Tx.GetNextIdentifier(BucketName)
|
||||
}
|
||||
@@ -38,3 +38,33 @@ func (service *Service) Tx(tx portainer.Transaction) ServiceTx {
|
||||
func (service *Service) Create(snapshot *portainer.Snapshot) error {
|
||||
return service.Connection.CreateObjectWithId(BucketName, int(snapshot.EndpointID), snapshot)
|
||||
}
|
||||
|
||||
func (service *Service) ReadWithoutSnapshotRaw(ID portainer.EndpointID) (*portainer.Snapshot, error) {
|
||||
var snapshot *portainer.Snapshot
|
||||
|
||||
err := service.Connection.ViewTx(func(tx portainer.Transaction) error {
|
||||
var err error
|
||||
snapshot, err = service.Tx(tx).ReadWithoutSnapshotRaw(ID)
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
return snapshot, err
|
||||
}
|
||||
|
||||
func (service *Service) ReadRawMessage(ID portainer.EndpointID) (*portainer.SnapshotRawMessage, error) {
|
||||
var snapshot *portainer.SnapshotRawMessage
|
||||
|
||||
err := service.Connection.ViewTx(func(tx portainer.Transaction) error {
|
||||
var err error
|
||||
snapshot, err = service.Tx(tx).ReadRawMessage(ID)
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
return snapshot, err
|
||||
}
|
||||
|
||||
func (service *Service) CreateRawMessage(snapshot *portainer.SnapshotRawMessage) error {
|
||||
return service.Connection.CreateObjectWithId(BucketName, int(snapshot.EndpointID), snapshot)
|
||||
}
|
||||
|
||||
@@ -12,3 +12,42 @@ type ServiceTx struct {
|
||||
func (service ServiceTx) Create(snapshot *portainer.Snapshot) error {
|
||||
return service.Tx.CreateObjectWithId(BucketName, int(snapshot.EndpointID), snapshot)
|
||||
}
|
||||
|
||||
func (service ServiceTx) ReadWithoutSnapshotRaw(ID portainer.EndpointID) (*portainer.Snapshot, error) {
|
||||
var snapshot struct {
|
||||
Docker *struct {
|
||||
X struct{} `json:"DockerSnapshotRaw"`
|
||||
*portainer.DockerSnapshot
|
||||
} `json:"Docker"`
|
||||
|
||||
portainer.Snapshot
|
||||
}
|
||||
|
||||
identifier := service.Connection.ConvertToKey(int(ID))
|
||||
|
||||
if err := service.Tx.GetObject(service.Bucket, identifier, &snapshot); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if snapshot.Docker != nil {
|
||||
snapshot.Snapshot.Docker = snapshot.Docker.DockerSnapshot
|
||||
}
|
||||
|
||||
return &snapshot.Snapshot, nil
|
||||
}
|
||||
|
||||
func (service ServiceTx) ReadRawMessage(ID portainer.EndpointID) (*portainer.SnapshotRawMessage, error) {
|
||||
var snapshot = portainer.SnapshotRawMessage{}
|
||||
|
||||
identifier := service.Connection.ConvertToKey(int(ID))
|
||||
|
||||
if err := service.Tx.GetObject(service.Bucket, identifier, &snapshot); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &snapshot, nil
|
||||
}
|
||||
|
||||
func (service ServiceTx) CreateRawMessage(snapshot *portainer.SnapshotRawMessage) error {
|
||||
return service.Tx.CreateObjectWithId(BucketName, int(snapshot.EndpointID), snapshot)
|
||||
}
|
||||
|
||||
@@ -232,7 +232,7 @@ func (store *Store) createAccount(username, password string, role portainer.User
|
||||
user := &portainer.User{Username: username, Role: role}
|
||||
|
||||
// encrypt the password
|
||||
cs := &crypto.Service{}
|
||||
cs := crypto.Service{}
|
||||
user.Password, err = cs.Hash(password)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -259,7 +259,7 @@ func (store *Store) checkAccount(username, expectPassword string, expectRole por
|
||||
}
|
||||
|
||||
// Check the password
|
||||
cs := &crypto.Service{}
|
||||
cs := crypto.Service{}
|
||||
expectPasswordHash, err := cs.Hash(expectPassword)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "hash failed")
|
||||
|
||||
@@ -40,13 +40,11 @@ func (store *Store) MigrateData() error {
|
||||
}
|
||||
|
||||
// before we alter anything in the DB, create a backup
|
||||
_, err = store.Backup("")
|
||||
if err != nil {
|
||||
if _, err := store.Backup(""); err != nil {
|
||||
return errors.Wrap(err, "while backing up database")
|
||||
}
|
||||
|
||||
err = store.FailSafeMigrate(migrator, version)
|
||||
if err != nil {
|
||||
if err := store.FailSafeMigrate(migrator, version); err != nil {
|
||||
err = errors.Wrap(err, "failed to migrate database")
|
||||
|
||||
log.Warn().Err(err).Msg("migration failed, restoring database to previous version")
|
||||
@@ -85,7 +83,9 @@ func (store *Store) newMigratorParameters(version *models.Version, flags *portai
|
||||
DockerhubService: store.DockerHubService,
|
||||
AuthorizationService: authorization.NewService(store),
|
||||
EdgeStackService: store.EdgeStackService,
|
||||
EdgeStackStatusService: store.EdgeStackStatusService,
|
||||
EdgeJobService: store.EdgeJobService,
|
||||
EdgeGroupService: store.EdgeGroupService,
|
||||
TunnelServerService: store.TunnelServerService,
|
||||
PendingActionsService: store.PendingActionsService,
|
||||
}
|
||||
@@ -140,8 +140,7 @@ func (store *Store) connectionRollback(force bool) error {
|
||||
}
|
||||
}
|
||||
|
||||
err := store.Restore()
|
||||
if err != nil {
|
||||
if err := store.Restore(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
31
api/datastore/migrator/migrate_2_31_0.go
Normal file
31
api/datastore/migrator/migrate_2_31_0.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package migrator
|
||||
|
||||
import portainer "github.com/portainer/portainer/api"
|
||||
|
||||
func (m *Migrator) migrateEdgeStacksStatuses_2_31_0() error {
|
||||
edgeStacks, err := m.edgeStackService.EdgeStacks()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, edgeStack := range edgeStacks {
|
||||
for envID, status := range edgeStack.Status {
|
||||
if err := m.edgeStackStatusService.Create(edgeStack.ID, envID, &portainer.EdgeStackStatusForEnv{
|
||||
EndpointID: envID,
|
||||
Status: status.Status,
|
||||
DeploymentInfo: status.DeploymentInfo,
|
||||
ReadyRePullImage: status.ReadyRePullImage,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
edgeStack.Status = nil
|
||||
|
||||
if err := m.edgeStackService.UpdateEdgeStack(edgeStack.ID, &edgeStack); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
33
api/datastore/migrator/migrate_2_32_0.go
Normal file
33
api/datastore/migrator/migrate_2_32_0.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package migrator
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
perrors "github.com/portainer/portainer/api/dataservices/errors"
|
||||
"github.com/portainer/portainer/api/internal/endpointutils"
|
||||
)
|
||||
|
||||
func (m *Migrator) addEndpointRelationForEdgeAgents_2_32_0() error {
|
||||
endpoints, err := m.endpointService.Endpoints()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, endpoint := range endpoints {
|
||||
if endpointutils.IsEdgeEndpoint(&endpoint) {
|
||||
_, err := m.endpointRelationService.EndpointRelation(endpoint.ID)
|
||||
if err != nil && errors.Is(err, perrors.ErrObjectNotFound) {
|
||||
relation := &portainer.EndpointRelation{
|
||||
EndpointID: endpoint.ID,
|
||||
EdgeStacks: make(map[portainer.EdgeStackID]bool),
|
||||
}
|
||||
|
||||
if err := m.endpointRelationService.Create(relation); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
25
api/datastore/migrator/migrate_2_33_0.go
Normal file
25
api/datastore/migrator/migrate_2_33_0.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package migrator
|
||||
|
||||
import (
|
||||
"github.com/portainer/portainer/api/roar"
|
||||
)
|
||||
|
||||
func (m *Migrator) migrateEdgeGroupEndpointsToRoars_2_33_0() error {
|
||||
egs, err := m.edgeGroupService.ReadAll()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, eg := range egs {
|
||||
if eg.EndpointIDs.Len() == 0 {
|
||||
eg.EndpointIDs = roar.FromSlice(eg.Endpoints)
|
||||
eg.Endpoints = nil
|
||||
}
|
||||
|
||||
if err := m.edgeGroupService.Update(eg.ID, &eg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
55
api/datastore/migrator/migrate_2_33_test.go
Normal file
55
api/datastore/migrator/migrate_2_33_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package migrator
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/database/boltdb"
|
||||
"github.com/portainer/portainer/api/dataservices/edgegroup"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMigrateEdgeGroupEndpointsToRoars_2_33_0Idempotency(t *testing.T) {
|
||||
var conn portainer.Connection = &boltdb.DbConnection{Path: t.TempDir()}
|
||||
err := conn.Open()
|
||||
require.NoError(t, err)
|
||||
|
||||
defer conn.Close()
|
||||
|
||||
edgeGroupService, err := edgegroup.NewService(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
edgeGroup := &portainer.EdgeGroup{
|
||||
ID: 1,
|
||||
Name: "test-edge-group",
|
||||
Endpoints: []portainer.EndpointID{1, 2, 3},
|
||||
}
|
||||
|
||||
err = conn.CreateObjectWithId(edgegroup.BucketName, int(edgeGroup.ID), edgeGroup)
|
||||
require.NoError(t, err)
|
||||
|
||||
m := NewMigrator(&MigratorParameters{EdgeGroupService: edgeGroupService})
|
||||
|
||||
// Run migration once
|
||||
|
||||
err = m.migrateEdgeGroupEndpointsToRoars_2_33_0()
|
||||
require.NoError(t, err)
|
||||
|
||||
migratedEdgeGroup, err := edgeGroupService.Read(edgeGroup.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, migratedEdgeGroup.Endpoints, 0)
|
||||
require.Equal(t, len(edgeGroup.Endpoints), migratedEdgeGroup.EndpointIDs.Len())
|
||||
|
||||
// Run migration again to ensure the results didn't change
|
||||
|
||||
err = m.migrateEdgeGroupEndpointsToRoars_2_33_0()
|
||||
require.NoError(t, err)
|
||||
|
||||
migratedEdgeGroup, err = edgeGroupService.Read(edgeGroup.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, migratedEdgeGroup.Endpoints, 0)
|
||||
require.Equal(t, len(edgeGroup.Endpoints), migratedEdgeGroup.EndpointIDs.Len())
|
||||
}
|
||||
@@ -94,6 +94,10 @@ func (m *Migrator) updateEdgeStackStatusForDB100() error {
|
||||
continue
|
||||
}
|
||||
|
||||
if environmentStatus.Details == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
statusArray := []portainer.EdgeStackDeploymentStatus{}
|
||||
if environmentStatus.Details.Pending {
|
||||
statusArray = append(statusArray, portainer.EdgeStackDeploymentStatus{
|
||||
|
||||
@@ -18,8 +18,7 @@ func (m *Migrator) updateResourceControlsToDBVersion22() error {
|
||||
for _, resourceControl := range legacyResourceControls {
|
||||
resourceControl.AdministratorsOnly = false
|
||||
|
||||
err := m.resourceControlService.Update(resourceControl.ID, &resourceControl)
|
||||
if err != nil {
|
||||
if err := m.resourceControlService.Update(resourceControl.ID, &resourceControl); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -42,8 +41,8 @@ func (m *Migrator) updateUsersAndRolesToDBVersion22() error {
|
||||
|
||||
for _, user := range legacyUsers {
|
||||
user.PortainerAuthorizations = authorization.DefaultPortainerAuthorizations()
|
||||
err = m.userService.Update(user.ID, &user)
|
||||
if err != nil {
|
||||
|
||||
if err := m.userService.Update(user.ID, &user); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -52,38 +51,47 @@ func (m *Migrator) updateUsersAndRolesToDBVersion22() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
endpointAdministratorRole.Priority = 1
|
||||
endpointAdministratorRole.Authorizations = authorization.DefaultEndpointAuthorizationsForEndpointAdministratorRole()
|
||||
|
||||
err = m.roleService.Update(endpointAdministratorRole.ID, endpointAdministratorRole)
|
||||
if err := m.roleService.Update(endpointAdministratorRole.ID, endpointAdministratorRole); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
helpDeskRole, err := m.roleService.Read(portainer.RoleID(2))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
helpDeskRole.Priority = 2
|
||||
helpDeskRole.Authorizations = authorization.DefaultEndpointAuthorizationsForHelpDeskRole(settings.AllowVolumeBrowserForRegularUsers)
|
||||
|
||||
err = m.roleService.Update(helpDeskRole.ID, helpDeskRole)
|
||||
if err := m.roleService.Update(helpDeskRole.ID, helpDeskRole); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
standardUserRole, err := m.roleService.Read(portainer.RoleID(3))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
standardUserRole.Priority = 3
|
||||
standardUserRole.Authorizations = authorization.DefaultEndpointAuthorizationsForStandardUserRole(settings.AllowVolumeBrowserForRegularUsers)
|
||||
|
||||
err = m.roleService.Update(standardUserRole.ID, standardUserRole)
|
||||
if err := m.roleService.Update(standardUserRole.ID, standardUserRole); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
readOnlyUserRole, err := m.roleService.Read(portainer.RoleID(4))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
readOnlyUserRole.Priority = 4
|
||||
readOnlyUserRole.Authorizations = authorization.DefaultEndpointAuthorizationsForReadOnlyUserRole(settings.AllowVolumeBrowserForRegularUsers)
|
||||
|
||||
err = m.roleService.Update(readOnlyUserRole.ID, readOnlyUserRole)
|
||||
if err != nil {
|
||||
if err := m.roleService.Update(readOnlyUserRole.ID, readOnlyUserRole); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -75,6 +75,10 @@ func (m *Migrator) updateEdgeStackStatusForDB80() error {
|
||||
|
||||
for _, edgeStack := range edgeStacks {
|
||||
for endpointId, status := range edgeStack.Status {
|
||||
if status.Details == nil {
|
||||
status.Details = &portainer.EdgeStackStatusDetails{}
|
||||
}
|
||||
|
||||
switch status.Type {
|
||||
case portainer.EdgeStackStatusPending:
|
||||
status.Details.Pending = true
|
||||
@@ -93,10 +97,10 @@ func (m *Migrator) updateEdgeStackStatusForDB80() error {
|
||||
edgeStack.Status[endpointId] = status
|
||||
}
|
||||
|
||||
err = m.edgeStackService.UpdateEdgeStack(edgeStack.ID, &edgeStack)
|
||||
if err != nil {
|
||||
if err := m.edgeStackService.UpdateEdgeStack(edgeStack.ID, &edgeStack); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,12 +3,13 @@ package migrator
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/Masterminds/semver"
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/database/models"
|
||||
"github.com/portainer/portainer/api/dataservices/dockerhub"
|
||||
"github.com/portainer/portainer/api/dataservices/edgegroup"
|
||||
"github.com/portainer/portainer/api/dataservices/edgejob"
|
||||
"github.com/portainer/portainer/api/dataservices/edgestack"
|
||||
"github.com/portainer/portainer/api/dataservices/edgestackstatus"
|
||||
"github.com/portainer/portainer/api/dataservices/endpoint"
|
||||
"github.com/portainer/portainer/api/dataservices/endpointgroup"
|
||||
"github.com/portainer/portainer/api/dataservices/endpointrelation"
|
||||
@@ -27,6 +28,8 @@ import (
|
||||
"github.com/portainer/portainer/api/dataservices/user"
|
||||
"github.com/portainer/portainer/api/dataservices/version"
|
||||
"github.com/portainer/portainer/api/internal/authorization"
|
||||
|
||||
"github.com/Masterminds/semver"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
@@ -56,7 +59,9 @@ type (
|
||||
authorizationService *authorization.Service
|
||||
dockerhubService *dockerhub.Service
|
||||
edgeStackService *edgestack.Service
|
||||
edgeStackStatusService *edgestackstatus.Service
|
||||
edgeJobService *edgejob.Service
|
||||
edgeGroupService *edgegroup.Service
|
||||
TunnelServerService *tunnelserver.Service
|
||||
pendingActionsService *pendingactions.Service
|
||||
}
|
||||
@@ -84,7 +89,9 @@ type (
|
||||
AuthorizationService *authorization.Service
|
||||
DockerhubService *dockerhub.Service
|
||||
EdgeStackService *edgestack.Service
|
||||
EdgeStackStatusService *edgestackstatus.Service
|
||||
EdgeJobService *edgejob.Service
|
||||
EdgeGroupService *edgegroup.Service
|
||||
TunnelServerService *tunnelserver.Service
|
||||
PendingActionsService *pendingactions.Service
|
||||
}
|
||||
@@ -114,12 +121,15 @@ func NewMigrator(parameters *MigratorParameters) *Migrator {
|
||||
authorizationService: parameters.AuthorizationService,
|
||||
dockerhubService: parameters.DockerhubService,
|
||||
edgeStackService: parameters.EdgeStackService,
|
||||
edgeStackStatusService: parameters.EdgeStackStatusService,
|
||||
edgeJobService: parameters.EdgeJobService,
|
||||
edgeGroupService: parameters.EdgeGroupService,
|
||||
TunnelServerService: parameters.TunnelServerService,
|
||||
pendingActionsService: parameters.PendingActionsService,
|
||||
}
|
||||
|
||||
migrator.initMigrations()
|
||||
|
||||
return migrator
|
||||
}
|
||||
|
||||
@@ -242,6 +252,12 @@ func (m *Migrator) initMigrations() {
|
||||
m.migratePendingActionsDataForDB130,
|
||||
)
|
||||
|
||||
m.addMigrations("2.31.0", m.migrateEdgeStacksStatuses_2_31_0)
|
||||
|
||||
m.addMigrations("2.32.0", m.addEndpointRelationForEdgeAgents_2_32_0)
|
||||
|
||||
m.addMigrations("2.33.1", m.migrateEdgeGroupEndpointsToRoars_2_33_0)
|
||||
|
||||
// Add new migrations above...
|
||||
// One function per migration, each versions migration funcs in the same file.
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package postinit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/docker/docker/client"
|
||||
@@ -83,17 +84,27 @@ func (postInitMigrator *PostInitMigrator) PostInitMigrate() error {
|
||||
|
||||
// try to create a post init migration pending action. If it already exists, do nothing
|
||||
// this function exists for readability, not reusability
|
||||
// TODO: This should be moved into pending actions as part of the pending action migration
|
||||
func (postInitMigrator *PostInitMigrator) createPostInitMigrationPendingAction(environmentID portainer.EndpointID) error {
|
||||
// If there are no pending actions for the given endpoint, create one
|
||||
err := postInitMigrator.dataStore.PendingActions().Create(&portainer.PendingAction{
|
||||
action := portainer.PendingAction{
|
||||
EndpointID: environmentID,
|
||||
Action: actions.PostInitMigrateEnvironment,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Error creating pending action for environment %d", environmentID)
|
||||
}
|
||||
return nil
|
||||
pendingActions, err := postInitMigrator.dataStore.PendingActions().ReadAll()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to retrieve pending actions: %w", err)
|
||||
}
|
||||
|
||||
for _, dba := range pendingActions {
|
||||
if dba.EndpointID == action.EndpointID && dba.Action == action.Action {
|
||||
log.Debug().
|
||||
Str("action", action.Action).
|
||||
Int("endpoint_id", int(action.EndpointID)).
|
||||
Msg("pending action already exists for environment, skipping...")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return postInitMigrator.dataStore.PendingActions().Create(&action)
|
||||
}
|
||||
|
||||
// MigrateEnvironment runs migrations on a single environment
|
||||
@@ -149,7 +160,7 @@ func (migrator *PostInitMigrator) MigrateGPUs(e portainer.Endpoint, dockerClient
|
||||
return migrator.dataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||
environment, err := tx.Endpoint().Endpoint(e.ID)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Error getting environment %d", environment.ID)
|
||||
log.Error().Err(err).Msgf("Error getting environment %d", e.ID)
|
||||
return err
|
||||
}
|
||||
// Early exit if we do not need to migrate!
|
||||
@@ -175,10 +186,11 @@ func (migrator *PostInitMigrator) MigrateGPUs(e portainer.Endpoint, dockerClient
|
||||
continue
|
||||
}
|
||||
|
||||
deviceRequests := containerDetails.HostConfig.Resources.DeviceRequests
|
||||
deviceRequests := containerDetails.HostConfig.DeviceRequests
|
||||
for _, deviceRequest := range deviceRequests {
|
||||
if deviceRequest.Driver == "nvidia" {
|
||||
environment.EnableGPUManagement = true
|
||||
|
||||
break containersLoop
|
||||
}
|
||||
}
|
||||
@@ -186,8 +198,7 @@ func (migrator *PostInitMigrator) MigrateGPUs(e portainer.Endpoint, dockerClient
|
||||
|
||||
// set the MigrateGPUs flag to false so we don't run this again
|
||||
environment.PostInitMigrations.MigrateGPUs = false
|
||||
err = tx.Endpoint().UpdateEndpoint(environment.ID, environment)
|
||||
if err != nil {
|
||||
if err := tx.Endpoint().UpdateEndpoint(environment.ID, environment); err != nil {
|
||||
log.Error().Err(err).Msgf("Error updating EnableGPUManagement flag for environment %d", environment.ID)
|
||||
return err
|
||||
}
|
||||
|
||||
170
api/datastore/postinit/migrate_post_init_test.go
Normal file
170
api/datastore/postinit/migrate_post_init_test.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package postinit
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/portainer/portainer/api/pendingactions/actions"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/docker/docker/client"
|
||||
"github.com/segmentio/encoding/json"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMigrateGPUs(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasSuffix(r.URL.Path, "/containers/json") {
|
||||
containerSummary := []container.Summary{{ID: "container1"}}
|
||||
|
||||
err := json.NewEncoder(w).Encode(containerSummary)
|
||||
require.NoError(t, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
container := container.InspectResponse{
|
||||
ContainerJSONBase: &container.ContainerJSONBase{
|
||||
ID: "container1",
|
||||
HostConfig: &container.HostConfig{
|
||||
Resources: container.Resources{
|
||||
DeviceRequests: []container.DeviceRequest{
|
||||
{Driver: "nvidia"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := json.NewEncoder(w).Encode(container)
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
_, store := datastore.MustNewTestStore(t, true, false)
|
||||
|
||||
migrator := &PostInitMigrator{dataStore: store}
|
||||
|
||||
dockerCli, err := client.NewClientWithOpts(client.WithHost(srv.URL), client.WithHTTPClient(http.DefaultClient))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Nonexistent endpoint
|
||||
|
||||
err = migrator.MigrateGPUs(portainer.Endpoint{}, dockerCli)
|
||||
require.Error(t, err)
|
||||
|
||||
// Valid endpoint
|
||||
|
||||
endpoint := portainer.Endpoint{ID: 1, PostInitMigrations: portainer.EndpointPostInitMigrations{MigrateGPUs: true}}
|
||||
|
||||
err = store.Endpoint().Create(&endpoint)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = migrator.MigrateGPUs(endpoint, dockerCli)
|
||||
require.NoError(t, err)
|
||||
|
||||
migratedEndpoint, err := store.Endpoint().Endpoint(endpoint.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, endpoint.ID, migratedEndpoint.ID)
|
||||
require.False(t, migratedEndpoint.PostInitMigrations.MigrateGPUs)
|
||||
require.True(t, migratedEndpoint.EnableGPUManagement)
|
||||
}
|
||||
|
||||
func TestPostInitMigrate_PendingActionsCreated(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
existingPendingActions []*portainer.PendingAction
|
||||
expectedPendingActions int
|
||||
expectedAction string
|
||||
}{
|
||||
{
|
||||
name: "when existing non-matching action exists, should add migration action",
|
||||
existingPendingActions: []*portainer.PendingAction{
|
||||
{
|
||||
EndpointID: 7,
|
||||
Action: "some-other-action",
|
||||
},
|
||||
},
|
||||
expectedPendingActions: 2,
|
||||
expectedAction: actions.PostInitMigrateEnvironment,
|
||||
},
|
||||
{
|
||||
name: "when matching action exists, should not add duplicate",
|
||||
existingPendingActions: []*portainer.PendingAction{
|
||||
{
|
||||
EndpointID: 7,
|
||||
Action: actions.PostInitMigrateEnvironment,
|
||||
},
|
||||
},
|
||||
expectedPendingActions: 1,
|
||||
expectedAction: actions.PostInitMigrateEnvironment,
|
||||
},
|
||||
{
|
||||
name: "when no actions exist, should add migration action",
|
||||
existingPendingActions: []*portainer.PendingAction{},
|
||||
expectedPendingActions: 1,
|
||||
expectedAction: actions.PostInitMigrateEnvironment,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
is := assert.New(t)
|
||||
_, store := datastore.MustNewTestStore(t, true, true)
|
||||
|
||||
// Create test endpoint
|
||||
endpoint := &portainer.Endpoint{
|
||||
ID: 7,
|
||||
UserTrusted: true,
|
||||
Type: portainer.EdgeAgentOnDockerEnvironment,
|
||||
Edge: portainer.EnvironmentEdgeSettings{
|
||||
AsyncMode: false,
|
||||
},
|
||||
EdgeID: "edgeID",
|
||||
}
|
||||
err := store.Endpoint().Create(endpoint)
|
||||
is.NoError(err, "error creating endpoint")
|
||||
|
||||
// Create any existing pending actions
|
||||
for _, action := range tt.existingPendingActions {
|
||||
err = store.PendingActions().Create(action)
|
||||
is.NoError(err, "error creating pending action")
|
||||
}
|
||||
|
||||
migrator := NewPostInitMigrator(
|
||||
nil, // kubeFactory not needed for this test
|
||||
nil, // dockerFactory not needed for this test
|
||||
store,
|
||||
"", // assetsPath not needed for this test
|
||||
nil, // kubernetesDeployer not needed for this test
|
||||
)
|
||||
|
||||
err = migrator.PostInitMigrate()
|
||||
is.NoError(err, "PostInitMigrate should not return error")
|
||||
|
||||
// Verify the results
|
||||
pendingActions, err := store.PendingActions().ReadAll()
|
||||
is.NoError(err, "error reading pending actions")
|
||||
is.Len(pendingActions, tt.expectedPendingActions, "unexpected number of pending actions")
|
||||
|
||||
// If we expect any actions, verify at least one has the expected action type
|
||||
if tt.expectedPendingActions > 0 {
|
||||
hasExpectedAction := false
|
||||
for _, action := range pendingActions {
|
||||
if action.Action == tt.expectedAction {
|
||||
hasExpectedAction = true
|
||||
is.Equal(endpoint.ID, action.EndpointID, "action should reference correct endpoint")
|
||||
break
|
||||
}
|
||||
}
|
||||
is.True(hasExpectedAction, "should have found action of expected type")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/portainer/portainer/api/dataservices/edgegroup"
|
||||
"github.com/portainer/portainer/api/dataservices/edgejob"
|
||||
"github.com/portainer/portainer/api/dataservices/edgestack"
|
||||
"github.com/portainer/portainer/api/dataservices/edgestackstatus"
|
||||
"github.com/portainer/portainer/api/dataservices/endpoint"
|
||||
"github.com/portainer/portainer/api/dataservices/endpointgroup"
|
||||
"github.com/portainer/portainer/api/dataservices/endpointrelation"
|
||||
@@ -39,6 +40,8 @@ import (
|
||||
"github.com/segmentio/encoding/json"
|
||||
)
|
||||
|
||||
var _ dataservices.DataStore = &Store{}
|
||||
|
||||
// Store defines the implementation of portainer.DataStore using
|
||||
// BoltDB as the storage system.
|
||||
type Store struct {
|
||||
@@ -51,6 +54,7 @@ type Store struct {
|
||||
EdgeGroupService *edgegroup.Service
|
||||
EdgeJobService *edgejob.Service
|
||||
EdgeStackService *edgestack.Service
|
||||
EdgeStackStatusService *edgestackstatus.Service
|
||||
EndpointGroupService *endpointgroup.Service
|
||||
EndpointService *endpoint.Service
|
||||
EndpointRelationService *endpointrelation.Service
|
||||
@@ -107,7 +111,13 @@ func (store *Store) initServices() error {
|
||||
return err
|
||||
}
|
||||
store.EdgeStackService = edgeStackService
|
||||
endpointRelationService.RegisterUpdateStackFunction(edgeStackService.UpdateEdgeStackFunc, edgeStackService.UpdateEdgeStackFuncTx)
|
||||
endpointRelationService.RegisterUpdateStackFunction(edgeStackService.UpdateEdgeStackFuncTx)
|
||||
|
||||
edgeStackStatusService, err := edgestackstatus.NewService(store.connection)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
store.EdgeStackStatusService = edgeStackStatusService
|
||||
|
||||
edgeGroupService, err := edgegroup.NewService(store.connection)
|
||||
if err != nil {
|
||||
@@ -269,6 +279,10 @@ func (store *Store) EdgeStack() dataservices.EdgeStackService {
|
||||
return store.EdgeStackService
|
||||
}
|
||||
|
||||
func (store *Store) EdgeStackStatus() dataservices.EdgeStackStatusService {
|
||||
return store.EdgeStackStatusService
|
||||
}
|
||||
|
||||
// Environment(Endpoint) gives access to the Environment(Endpoint) data management layer
|
||||
func (store *Store) Endpoint() dataservices.EndpointService {
|
||||
return store.EndpointService
|
||||
|
||||
@@ -14,7 +14,9 @@ func (tx *StoreTx) IsErrObjectNotFound(err error) bool {
|
||||
return tx.store.IsErrObjectNotFound(err)
|
||||
}
|
||||
|
||||
func (tx *StoreTx) CustomTemplate() dataservices.CustomTemplateService { return nil }
|
||||
func (tx *StoreTx) CustomTemplate() dataservices.CustomTemplateService {
|
||||
return tx.store.CustomTemplateService.Tx(tx.tx)
|
||||
}
|
||||
|
||||
func (tx *StoreTx) PendingActions() dataservices.PendingActionsService {
|
||||
return tx.store.PendingActionsService.Tx(tx.tx)
|
||||
@@ -32,6 +34,10 @@ func (tx *StoreTx) EdgeStack() dataservices.EdgeStackService {
|
||||
return tx.store.EdgeStackService.Tx(tx.tx)
|
||||
}
|
||||
|
||||
func (tx *StoreTx) EdgeStackStatus() dataservices.EdgeStackStatusService {
|
||||
return tx.store.EdgeStackStatusService.Tx(tx.tx)
|
||||
}
|
||||
|
||||
func (tx *StoreTx) Endpoint() dataservices.EndpointService {
|
||||
return tx.store.EndpointService.Tx(tx.tx)
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
}
|
||||
],
|
||||
"edge_stack": null,
|
||||
"edge_stack_status": null,
|
||||
"edgegroups": null,
|
||||
"edgejobs": null,
|
||||
"endpoint_groups": [
|
||||
@@ -120,6 +121,10 @@
|
||||
"Ecr": {
|
||||
"Region": ""
|
||||
},
|
||||
"Github": {
|
||||
"OrganisationName": "",
|
||||
"UseOrganisation": false
|
||||
},
|
||||
"Gitlab": {
|
||||
"InstanceURL": "",
|
||||
"ProjectId": 0,
|
||||
@@ -605,12 +610,12 @@
|
||||
"GlobalDeploymentOptions": {
|
||||
"hideStacksFunctionality": false
|
||||
},
|
||||
"HelmRepositoryURL": "",
|
||||
"HelmRepositoryURL": "https://charts.bitnami.com/bitnami",
|
||||
"InternalAuthSettings": {
|
||||
"RequiredPasswordLength": 12
|
||||
},
|
||||
"KubeconfigExpiry": "0",
|
||||
"KubectlShellImage": "portainer/kubectl-shell:2.27.1",
|
||||
"KubectlShellImage": "portainer/kubectl-shell:2.33.6",
|
||||
"LDAPSettings": {
|
||||
"AnonymousMode": true,
|
||||
"AutoCreateUsers": true,
|
||||
@@ -678,14 +683,11 @@
|
||||
"Images": null,
|
||||
"Info": {
|
||||
"Architecture": "",
|
||||
"BridgeNfIp6tables": false,
|
||||
"BridgeNfIptables": false,
|
||||
"CDISpecDirs": null,
|
||||
"CPUSet": false,
|
||||
"CPUShares": false,
|
||||
"CgroupDriver": "",
|
||||
"ContainerdCommit": {
|
||||
"Expected": "",
|
||||
"ID": ""
|
||||
},
|
||||
"Containers": 0,
|
||||
@@ -709,7 +711,6 @@
|
||||
"IndexServerAddress": "",
|
||||
"InitBinary": "",
|
||||
"InitCommit": {
|
||||
"Expected": "",
|
||||
"ID": ""
|
||||
},
|
||||
"Isolation": "",
|
||||
@@ -738,7 +739,6 @@
|
||||
},
|
||||
"RegistryConfig": null,
|
||||
"RuncCommit": {
|
||||
"Expected": "",
|
||||
"ID": ""
|
||||
},
|
||||
"Runtimes": null,
|
||||
@@ -780,6 +780,7 @@
|
||||
"ImageCount": 9,
|
||||
"IsPodman": false,
|
||||
"NodeCount": 0,
|
||||
"PerformanceMetrics": null,
|
||||
"RunningContainerCount": 5,
|
||||
"ServiceCount": 0,
|
||||
"StackCount": 2,
|
||||
@@ -943,7 +944,7 @@
|
||||
}
|
||||
],
|
||||
"version": {
|
||||
"VERSION": "{\"SchemaVersion\":\"2.27.1\",\"MigratorCount\":0,\"Edition\":1,\"InstanceID\":\"463d5c47-0ea5-4aca-85b1-405ceefee254\"}"
|
||||
"VERSION": "{\"SchemaVersion\":\"2.33.6\",\"MigratorCount\":0,\"Edition\":1,\"InstanceID\":\"463d5c47-0ea5-4aca-85b1-405ceefee254\"}"
|
||||
},
|
||||
"webhooks": null
|
||||
}
|
||||
@@ -44,7 +44,7 @@ func NewTestStore(t testing.TB, init, secure bool) (bool, *Store, func(), error)
|
||||
secretKey = nil
|
||||
}
|
||||
|
||||
connection, err := database.NewDatabase("boltdb", storePath, secretKey)
|
||||
connection, err := database.NewDatabase("boltdb", storePath, secretKey, false)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
package validate
|
||||
|
||||
import (
|
||||
"github.com/go-playground/validator/v10"
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
)
|
||||
|
||||
var validate *validator.Validate
|
||||
|
||||
func ValidateLDAPSettings(ldp *portainer.LDAPSettings) error {
|
||||
validate = validator.New()
|
||||
registerValidationMethods(validate)
|
||||
|
||||
return validate.Struct(ldp)
|
||||
}
|
||||
@@ -1,61 +0,0 @@
|
||||
package validate
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
)
|
||||
|
||||
func TestValidateLDAPSettings(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ldap portainer.LDAPSettings
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Empty LDAP Settings",
|
||||
ldap: portainer.LDAPSettings{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "With URL",
|
||||
ldap: portainer.LDAPSettings{
|
||||
AnonymousMode: true,
|
||||
URL: "192.168.0.1:323",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Validate URL and URLs",
|
||||
ldap: portainer.LDAPSettings{
|
||||
AnonymousMode: true,
|
||||
URL: "192.168.0.1:323",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "validate client ldap",
|
||||
ldap: portainer.LDAPSettings{
|
||||
AnonymousMode: false,
|
||||
ReaderDN: "CN=LDAP API Service Account",
|
||||
Password: "Qu**dfUUU**",
|
||||
URL: "aukdc15.pgc.co:389",
|
||||
TLSConfig: portainer.TLSConfiguration{
|
||||
TLS: false,
|
||||
TLSSkipVerify: false,
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateLDAPSettings(&tt.ldap)
|
||||
if (err == nil) == tt.wantErr {
|
||||
t.Errorf("No error expected but got %s", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
package validate
|
||||
|
||||
import (
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
func registerValidationMethods(v *validator.Validate) {
|
||||
v.RegisterValidation("validate_bool", ValidateBool)
|
||||
}
|
||||
|
||||
/**
|
||||
* Validation methods below are being used for custom validation
|
||||
*/
|
||||
func ValidateBool(fl validator.FieldLevel) bool {
|
||||
_, ok := fl.Field().Interface().(bool)
|
||||
return ok
|
||||
}
|
||||
@@ -24,6 +24,8 @@ const (
|
||||
dockerClientVersion = "1.37"
|
||||
)
|
||||
|
||||
type NodeNamesCtxKey struct{}
|
||||
|
||||
// ClientFactory is used to create Docker clients
|
||||
type ClientFactory struct {
|
||||
signatureService portainer.DigitalSignatureService
|
||||
@@ -73,19 +75,6 @@ func createLocalClient(endpoint *portainer.Endpoint) (*client.Client, error) {
|
||||
)
|
||||
}
|
||||
|
||||
func CreateClientFromEnv() (*client.Client, error) {
|
||||
return client.NewClientWithOpts(
|
||||
client.FromEnv,
|
||||
client.WithAPIVersionNegotiation(),
|
||||
)
|
||||
}
|
||||
|
||||
func CreateSimpleClient() (*client.Client, error) {
|
||||
return client.NewClientWithOpts(
|
||||
client.WithAPIVersionNegotiation(),
|
||||
)
|
||||
}
|
||||
|
||||
func createTCPClient(endpoint *portainer.Endpoint, timeout *time.Duration) (*client.Client, error) {
|
||||
httpCli, err := httpClient(endpoint, timeout)
|
||||
if err != nil {
|
||||
@@ -175,7 +164,7 @@ func (t *NodeNameTransport) RoundTrip(req *http.Request) (*http.Response, error)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
nodeNames, ok := req.Context().Value("nodeNames").(map[string]string)
|
||||
nodeNames, ok := req.Context().Value(NodeNamesCtxKey{}).(map[string]string)
|
||||
if ok {
|
||||
for idx, r := range rs {
|
||||
// as there is no way to differentiate the same image available in multiple nodes only by their ID
|
||||
@@ -194,10 +183,11 @@ func httpClient(endpoint *portainer.Endpoint, timeout *time.Duration) (*http.Cli
|
||||
}
|
||||
|
||||
if endpoint.TLSConfig.TLS {
|
||||
tlsConfig, err := crypto.CreateTLSConfigurationFromDisk(endpoint.TLSConfig.TLSCACertPath, endpoint.TLSConfig.TLSCertPath, endpoint.TLSConfig.TLSKeyPath, endpoint.TLSConfig.TLSSkipVerify)
|
||||
tlsConfig, err := crypto.CreateTLSConfigurationFromDisk(endpoint.TLSConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
transport.TLSClientConfig = tlsConfig
|
||||
}
|
||||
|
||||
|
||||
30
api/docker/client/client_test.go
Normal file
30
api/docker/client/client_test.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHttpClient(t *testing.T) {
|
||||
fips.InitFIPS(false)
|
||||
|
||||
// Valid TLS configuration
|
||||
endpoint := &portainer.Endpoint{}
|
||||
endpoint.TLSConfig = portainer.TLSConfiguration{TLS: true}
|
||||
|
||||
cli, err := httpClient(endpoint, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cli)
|
||||
|
||||
// Invalid TLS configuration
|
||||
endpoint.TLSConfig.TLSCertPath = "/invalid/path/client.crt"
|
||||
endpoint.TLSConfig.TLSKeyPath = "/invalid/path/client.key"
|
||||
|
||||
cli, err = httpClient(endpoint, nil)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, cli)
|
||||
}
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
|
||||
"github.com/containers/image/v5/docker"
|
||||
imagetypes "github.com/containers/image/v5/types"
|
||||
"github.com/docker/docker/api/types"
|
||||
"github.com/docker/docker/api/types/image"
|
||||
"github.com/opencontainers/go-digest"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog/log"
|
||||
@@ -38,10 +38,10 @@ func NewClientWithRegistry(registryClient *RegistryClient, clientFactory *docker
|
||||
func (c *DigestClient) RemoteDigest(image Image) (digest.Digest, error) {
|
||||
ctx, cancel := c.timeoutContext()
|
||||
defer cancel()
|
||||
|
||||
// Docker references with both a tag and digest are currently not supported
|
||||
if image.Tag != "" && image.Digest != "" {
|
||||
err := image.trimDigest()
|
||||
if err != nil {
|
||||
if err := image.TrimDigest(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
@@ -69,7 +69,7 @@ func (c *DigestClient) RemoteDigest(image Image) (digest.Digest, error) {
|
||||
// Retrieve remote digest through HEAD request
|
||||
rmDigest, err := docker.GetDigest(ctx, sysCtx, rmRef)
|
||||
if err != nil {
|
||||
// fallback to public registry for hub
|
||||
// Fallback to public registry for hub
|
||||
if image.HubLink != "" {
|
||||
rmDigest, err = docker.GetDigest(ctx, c.sysCtx, rmRef)
|
||||
if err == nil {
|
||||
@@ -85,7 +85,7 @@ func (c *DigestClient) RemoteDigest(image Image) (digest.Digest, error) {
|
||||
return rmDigest, nil
|
||||
}
|
||||
|
||||
func ParseLocalImage(inspect types.ImageInspect) (*Image, error) {
|
||||
func ParseLocalImage(inspect image.InspectResponse) (*Image, error) {
|
||||
if IsLocalImage(inspect) || IsDanglingImage(inspect) {
|
||||
return nil, errors.New("the image is not regular")
|
||||
}
|
||||
@@ -131,8 +131,7 @@ func ParseRepoDigests(repoDigests []string) []digest.Digest {
|
||||
func ParseRepoTags(repoTags []string) []*Image {
|
||||
images := make([]*Image, 0)
|
||||
for _, repoTag := range repoTags {
|
||||
image := ParseRepoTag(repoTag)
|
||||
if image != nil {
|
||||
if image := ParseRepoTag(repoTag); image != nil {
|
||||
images = append(images, image)
|
||||
}
|
||||
}
|
||||
@@ -147,7 +146,7 @@ func ParseRepoDigest(repoDigest string) digest.Digest {
|
||||
|
||||
d, err := digest.Parse(strings.Split(repoDigest, "@")[1])
|
||||
if err != nil {
|
||||
log.Warn().Msgf("Skip invalid repo digest item: %s [error: %v]", repoDigest, err)
|
||||
log.Warn().Err(err).Str("digest", repoDigest).Msg("skip invalid repo item")
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
33
api/docker/images/digest_test.go
Normal file
33
api/docker/images/digest_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package images
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/docker/docker/api/types/image"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseLocalImage(t *testing.T) {
|
||||
// Test with a regular image
|
||||
|
||||
img, err := ParseLocalImage(image.InspectResponse{
|
||||
ID: "sha256:9234e8fb04c47cfe0f49931e4ac7eb76fa904e33b7f8576aec0501c085f02516",
|
||||
RepoTags: []string{"myimage:latest"},
|
||||
RepoDigests: []string{"myimage@sha256:4bcff63911fcb4448bd4fdacec207030997caf25e9bea4045fa6c8c44de311d1"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, img)
|
||||
require.Equal(t, "library/myimage", img.Path)
|
||||
require.Equal(t, "latest", img.Tag)
|
||||
require.Equal(t, "sha256:4bcff63911fcb4448bd4fdacec207030997caf25e9bea4045fa6c8c44de311d1", img.Digest.String())
|
||||
|
||||
// Test with a dangling image
|
||||
|
||||
img, err = ParseLocalImage(image.InspectResponse{
|
||||
ID: "sha256:abcdef1234567890",
|
||||
RepoTags: []string{"<none>:<none>"},
|
||||
RepoDigests: []string{"<none>@<none>"},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Nil(t, img)
|
||||
}
|
||||
@@ -7,9 +7,9 @@ import (
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/docker/docker/api/types"
|
||||
|
||||
"github.com/containers/image/v5/docker/reference"
|
||||
"github.com/docker/docker/api/types"
|
||||
"github.com/docker/docker/api/types/image"
|
||||
"github.com/opencontainers/go-digest"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
@@ -26,7 +26,7 @@ type Image struct {
|
||||
Digest digest.Digest
|
||||
HubLink string
|
||||
named reference.Named
|
||||
opts ParseImageOptions
|
||||
Opts ParseImageOptions `json:"-"`
|
||||
}
|
||||
|
||||
// ParseImageOptions holds image options for parsing.
|
||||
@@ -43,9 +43,10 @@ func (i *Image) Name() string {
|
||||
// FullName return the real full name may include Tag or Digest of the image, Tag first.
|
||||
func (i *Image) FullName() string {
|
||||
if i.Tag == "" {
|
||||
return fmt.Sprintf("%s@%s", i.Name(), i.Digest)
|
||||
return i.Name() + "@" + i.Digest.String()
|
||||
}
|
||||
return fmt.Sprintf("%s:%s", i.Name(), i.Tag)
|
||||
|
||||
return i.Name() + ":" + i.Tag
|
||||
}
|
||||
|
||||
// String returns the string representation of an image, including Tag and Digest if existed.
|
||||
@@ -66,22 +67,25 @@ func (i *Image) Reference() string {
|
||||
func (i *Image) WithDigest(digest digest.Digest) (err error) {
|
||||
i.Digest = digest
|
||||
i.named, err = reference.WithDigest(i.named, digest)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (i *Image) WithTag(tag string) (err error) {
|
||||
i.Tag = tag
|
||||
i.named, err = reference.WithTag(i.named, tag)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (i *Image) trimDigest() error {
|
||||
func (i *Image) TrimDigest() error {
|
||||
i.Digest = ""
|
||||
named, err := ParseImage(ParseImageOptions{Name: i.FullName()})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.named = &named
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -92,11 +96,12 @@ func ParseImage(parseOpts ParseImageOptions) (Image, error) {
|
||||
if err != nil {
|
||||
return Image{}, errors.Wrapf(err, "parsing image %s failed", parseOpts.Name)
|
||||
}
|
||||
|
||||
// Add the latest lag if they did not provide one.
|
||||
named = reference.TagNameOnly(named)
|
||||
|
||||
i := Image{
|
||||
opts: parseOpts,
|
||||
Opts: parseOpts,
|
||||
named: named,
|
||||
Domain: reference.Domain(named),
|
||||
Path: reference.Path(named),
|
||||
@@ -122,15 +127,16 @@ func ParseImage(parseOpts ParseImageOptions) (Image, error) {
|
||||
}
|
||||
|
||||
func (i *Image) hubLink() (string, error) {
|
||||
if i.opts.HubTpl != "" {
|
||||
if i.Opts.HubTpl != "" {
|
||||
var out bytes.Buffer
|
||||
tmpl, err := template.New("tmpl").
|
||||
Option("missingkey=error").
|
||||
Parse(i.opts.HubTpl)
|
||||
Parse(i.Opts.HubTpl)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
err = tmpl.Execute(&out, i)
|
||||
|
||||
return out.String(), err
|
||||
}
|
||||
|
||||
@@ -142,6 +148,7 @@ func (i *Image) hubLink() (string, error) {
|
||||
prefix = "_"
|
||||
path = strings.Replace(i.Path, "library/", "", 1)
|
||||
}
|
||||
|
||||
return "https://hub.docker.com/" + prefix + "/" + path, nil
|
||||
case "docker.bintray.io", "jfrog-docker-reg2.bintray.io":
|
||||
return "https://bintray.com/jfrog/reg2/" + strings.ReplaceAll(i.Path, "/", "%3A"), nil
|
||||
@@ -172,11 +179,11 @@ func IsLocalImage(image types.ImageInspect) bool {
|
||||
// IsDanglingImage returns whether the given image is "dangling" which means
|
||||
// that there are no repository references to the given image and it has no
|
||||
// child images
|
||||
func IsDanglingImage(image types.ImageInspect) bool {
|
||||
func IsDanglingImage(image image.InspectResponse) bool {
|
||||
return len(image.RepoTags) == 1 && image.RepoTags[0] == "<none>:<none>" && len(image.RepoDigests) == 1 && image.RepoDigests[0] == "<none>@<none>"
|
||||
}
|
||||
|
||||
// IsNoTagImage returns whether the given image is damaged, has no tags
|
||||
func IsNoTagImage(image types.ImageInspect) bool {
|
||||
func IsNoTagImage(image image.InspectResponse) bool {
|
||||
return len(image.RepoTags) == 0
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ func TestImageParser(t *testing.T) {
|
||||
})
|
||||
is.NoError(err, "")
|
||||
is.Equal("docker.io/portainer/portainer-ee:latest", image.FullName())
|
||||
is.Equal("portainer/portainer-ee", image.opts.Name)
|
||||
is.Equal("portainer/portainer-ee", image.Opts.Name)
|
||||
is.Equal("latest", image.Tag)
|
||||
is.Equal("portainer/portainer-ee", image.Path)
|
||||
is.Equal("docker.io", image.Domain)
|
||||
@@ -32,7 +32,7 @@ func TestImageParser(t *testing.T) {
|
||||
})
|
||||
is.NoError(err, "")
|
||||
is.Equal("gcr.io/k8s-minikube/kicbase@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", image.FullName())
|
||||
is.Equal("gcr.io/k8s-minikube/kicbase@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", image.opts.Name)
|
||||
is.Equal("gcr.io/k8s-minikube/kicbase@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", image.Opts.Name)
|
||||
is.Equal("", image.Tag)
|
||||
is.Equal("k8s-minikube/kicbase", image.Path)
|
||||
is.Equal("gcr.io", image.Domain)
|
||||
@@ -49,7 +49,7 @@ func TestImageParser(t *testing.T) {
|
||||
})
|
||||
is.NoError(err, "")
|
||||
is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30", image.FullName())
|
||||
is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", image.opts.Name)
|
||||
is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", image.Opts.Name)
|
||||
is.Equal("v0.0.30", image.Tag)
|
||||
is.Equal("k8s-minikube/kicbase", image.Path)
|
||||
is.Equal("gcr.io", image.Domain)
|
||||
@@ -71,7 +71,7 @@ func TestUpdateParsedImage(t *testing.T) {
|
||||
is.NoError(err, "")
|
||||
_ = image.WithTag("v0.0.31")
|
||||
is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.31", image.FullName())
|
||||
is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", image.opts.Name)
|
||||
is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", image.Opts.Name)
|
||||
is.Equal("v0.0.31", image.Tag)
|
||||
is.Equal("k8s-minikube/kicbase", image.Path)
|
||||
is.Equal("gcr.io", image.Domain)
|
||||
@@ -89,7 +89,7 @@ func TestUpdateParsedImage(t *testing.T) {
|
||||
is.NoError(err, "")
|
||||
_ = image.WithDigest("sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b3")
|
||||
is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30", image.FullName())
|
||||
is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", image.opts.Name)
|
||||
is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", image.Opts.Name)
|
||||
is.Equal("v0.0.30", image.Tag)
|
||||
is.Equal("k8s-minikube/kicbase", image.Path)
|
||||
is.Equal("gcr.io", image.Domain)
|
||||
@@ -105,9 +105,9 @@ func TestUpdateParsedImage(t *testing.T) {
|
||||
Name: "gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2",
|
||||
})
|
||||
is.NoError(err, "")
|
||||
_ = image.trimDigest()
|
||||
_ = image.TrimDigest()
|
||||
is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30", image.FullName())
|
||||
is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", image.opts.Name)
|
||||
is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", image.Opts.Name)
|
||||
is.Equal("v0.0.30", image.Tag)
|
||||
is.Equal("k8s-minikube/kicbase", image.Path)
|
||||
is.Equal("gcr.io", image.Domain)
|
||||
|
||||
@@ -29,7 +29,7 @@ func (c *RegistryClient) RegistryAuth(image Image) (string, string, error) {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
registry, err := findBestMatchRegistry(image.opts.Name, registries)
|
||||
registry, err := findBestMatchRegistry(image.Opts.Name, registries)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
@@ -59,7 +59,7 @@ func (c *RegistryClient) EncodedRegistryAuth(image Image) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
registry, err := findBestMatchRegistry(image.opts.Name, registries)
|
||||
registry, err := findBestMatchRegistry(image.Opts.Name, registries)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -49,6 +49,11 @@ type (
|
||||
|
||||
// Is relative path supported
|
||||
SupportRelativePath bool
|
||||
// AlwaysCloneGitRepoForRelativePath is a flag indicating if the agent must always clone the git repository for relative path.
|
||||
// This field is only valid when SupportRelativePath is true.
|
||||
// Used only for EE
|
||||
AlwaysCloneGitRepoForRelativePath bool
|
||||
|
||||
// Mount point for relative path
|
||||
FilesystemPath string
|
||||
// Used only for EE
|
||||
|
||||
@@ -4,10 +4,12 @@ import (
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
)
|
||||
|
||||
type kubernetesMockDeployer struct{}
|
||||
type kubernetesMockDeployer struct {
|
||||
portainer.KubernetesDeployer
|
||||
}
|
||||
|
||||
// NewKubernetesDeployer creates a mock kubernetes deployer
|
||||
func NewKubernetesDeployer() portainer.KubernetesDeployer {
|
||||
func NewKubernetesDeployer() *kubernetesMockDeployer {
|
||||
return &kubernetesMockDeployer{}
|
||||
}
|
||||
|
||||
@@ -18,3 +20,7 @@ func (deployer *kubernetesMockDeployer) Deploy(userID portainer.UserID, endpoint
|
||||
func (deployer *kubernetesMockDeployer) Remove(userID portainer.UserID, endpoint *portainer.Endpoint, manifestFiles []string, namespace string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (deployer *kubernetesMockDeployer) Restart(userID portainer.UserID, endpoint *portainer.Endpoint, manifestFiles []string, namespace string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
@@ -1,13 +1,8 @@
|
||||
package exec
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
@@ -15,13 +10,17 @@ import (
|
||||
"github.com/portainer/portainer/api/http/proxy/factory"
|
||||
"github.com/portainer/portainer/api/http/proxy/factory/kubernetes"
|
||||
"github.com/portainer/portainer/api/kubernetes/cli"
|
||||
"github.com/portainer/portainer/pkg/libkubectl"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultServerURL = "https://kubernetes.default.svc"
|
||||
)
|
||||
|
||||
// KubernetesDeployer represents a service to deploy resources inside a Kubernetes environment(endpoint).
|
||||
type KubernetesDeployer struct {
|
||||
binaryPath string
|
||||
dataStore dataservices.DataStore
|
||||
reverseTunnelService portainer.ReverseTunnelService
|
||||
signatureService portainer.DigitalSignatureService
|
||||
@@ -31,9 +30,8 @@ type KubernetesDeployer struct {
|
||||
}
|
||||
|
||||
// NewKubernetesDeployer initializes a new KubernetesDeployer service.
|
||||
func NewKubernetesDeployer(kubernetesTokenCacheManager *kubernetes.TokenCacheManager, kubernetesClientFactory *cli.ClientFactory, datastore dataservices.DataStore, reverseTunnelService portainer.ReverseTunnelService, signatureService portainer.DigitalSignatureService, proxyManager *proxy.Manager, binaryPath string) *KubernetesDeployer {
|
||||
func NewKubernetesDeployer(kubernetesTokenCacheManager *kubernetes.TokenCacheManager, kubernetesClientFactory *cli.ClientFactory, datastore dataservices.DataStore, reverseTunnelService portainer.ReverseTunnelService, signatureService portainer.DigitalSignatureService, proxyManager *proxy.Manager) *KubernetesDeployer {
|
||||
return &KubernetesDeployer{
|
||||
binaryPath: binaryPath,
|
||||
dataStore: datastore,
|
||||
reverseTunnelService: reverseTunnelService,
|
||||
signatureService: signatureService,
|
||||
@@ -78,63 +76,56 @@ func (deployer *KubernetesDeployer) getToken(userID portainer.UserID, endpoint *
|
||||
}
|
||||
|
||||
// Deploy upserts Kubernetes resources defined in manifest(s)
|
||||
func (deployer *KubernetesDeployer) Deploy(userID portainer.UserID, endpoint *portainer.Endpoint, manifestFiles []string, namespace string) (string, error) {
|
||||
return deployer.command("apply", userID, endpoint, manifestFiles, namespace)
|
||||
func (deployer *KubernetesDeployer) Deploy(userID portainer.UserID, endpoint *portainer.Endpoint, resources []string, namespace string) (string, error) {
|
||||
return deployer.command("apply", userID, endpoint, resources, namespace)
|
||||
}
|
||||
|
||||
// Remove deletes Kubernetes resources defined in manifest(s)
|
||||
func (deployer *KubernetesDeployer) Remove(userID portainer.UserID, endpoint *portainer.Endpoint, manifestFiles []string, namespace string) (string, error) {
|
||||
return deployer.command("delete", userID, endpoint, manifestFiles, namespace)
|
||||
func (deployer *KubernetesDeployer) Remove(userID portainer.UserID, endpoint *portainer.Endpoint, resources []string, namespace string) (string, error) {
|
||||
return deployer.command("delete", userID, endpoint, resources, namespace)
|
||||
}
|
||||
|
||||
func (deployer *KubernetesDeployer) command(operation string, userID portainer.UserID, endpoint *portainer.Endpoint, manifestFiles []string, namespace string) (string, error) {
|
||||
func (deployer *KubernetesDeployer) command(operation string, userID portainer.UserID, endpoint *portainer.Endpoint, resources []string, namespace string) (string, error) {
|
||||
token, err := deployer.getToken(userID, endpoint, endpoint.Type == portainer.KubernetesLocalEnvironment)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed generating a user token")
|
||||
}
|
||||
|
||||
command := path.Join(deployer.binaryPath, "kubectl")
|
||||
if runtime.GOOS == "windows" {
|
||||
command = path.Join(deployer.binaryPath, "kubectl.exe")
|
||||
}
|
||||
|
||||
args := []string{"--token", token}
|
||||
if namespace != "" {
|
||||
args = append(args, "--namespace", namespace)
|
||||
}
|
||||
|
||||
serverURL := defaultServerURL
|
||||
if endpoint.Type == portainer.AgentOnKubernetesEnvironment || endpoint.Type == portainer.EdgeAgentOnKubernetesEnvironment {
|
||||
url, proxy, err := deployer.getAgentURL(endpoint)
|
||||
if err != nil {
|
||||
return "", errors.WithMessage(err, "failed generating endpoint URL")
|
||||
}
|
||||
|
||||
defer proxy.Close()
|
||||
args = append(args, "--server", url)
|
||||
args = append(args, "--insecure-skip-tls-verify")
|
||||
|
||||
serverURL = url
|
||||
}
|
||||
|
||||
if operation == "delete" {
|
||||
args = append(args, "--ignore-not-found=true")
|
||||
}
|
||||
|
||||
args = append(args, operation)
|
||||
for _, path := range manifestFiles {
|
||||
args = append(args, "-f", strings.TrimSpace(path))
|
||||
}
|
||||
|
||||
var stderr bytes.Buffer
|
||||
cmd := exec.Command(command, args...)
|
||||
cmd.Env = os.Environ()
|
||||
cmd.Env = append(cmd.Env, "POD_NAMESPACE=default")
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
output, err := cmd.Output()
|
||||
client, err := libkubectl.NewClient(&libkubectl.ClientAccess{
|
||||
Token: token,
|
||||
ServerUrl: serverURL,
|
||||
}, namespace, "", true)
|
||||
if err != nil {
|
||||
return "", errors.Wrapf(err, "failed to execute kubectl command: %q", stderr.String())
|
||||
return "", errors.Wrap(err, "failed to create kubectl client")
|
||||
}
|
||||
|
||||
return string(output), nil
|
||||
operations := map[string]func(context.Context, []string) (string, error){
|
||||
"apply": client.Apply,
|
||||
"delete": client.Delete,
|
||||
}
|
||||
|
||||
operationFunc, ok := operations[operation]
|
||||
if !ok {
|
||||
return "", errors.Errorf("unsupported operation: %s", operation)
|
||||
}
|
||||
|
||||
output, err := operationFunc(context.Background(), resources)
|
||||
if err != nil {
|
||||
return "", errors.Wrapf(err, "failed to execute kubectl %s command", operation)
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
func (deployer *KubernetesDeployer) getAgentURL(endpoint *portainer.Endpoint) (string, *factory.ProxyServer, error) {
|
||||
|
||||
173
api/exec/kubernetes_deploy_test.go
Normal file
173
api/exec/kubernetes_deploy_test.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package exec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type mockKubectlClient struct {
|
||||
applyFunc func(ctx context.Context, files []string) error
|
||||
deleteFunc func(ctx context.Context, files []string) error
|
||||
rolloutRestartFunc func(ctx context.Context, resources []string) error
|
||||
}
|
||||
|
||||
func (m *mockKubectlClient) Apply(ctx context.Context, files []string) error {
|
||||
if m.applyFunc != nil {
|
||||
return m.applyFunc(ctx, files)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockKubectlClient) Delete(ctx context.Context, files []string) error {
|
||||
if m.deleteFunc != nil {
|
||||
return m.deleteFunc(ctx, files)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockKubectlClient) RolloutRestart(ctx context.Context, resources []string) error {
|
||||
if m.rolloutRestartFunc != nil {
|
||||
return m.rolloutRestartFunc(ctx, resources)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func testExecuteKubectlOperation(client *mockKubectlClient, operation string, manifestFiles []string) error {
|
||||
operations := map[string]func(context.Context, []string) error{
|
||||
"apply": client.Apply,
|
||||
"delete": client.Delete,
|
||||
"rollout-restart": client.RolloutRestart,
|
||||
}
|
||||
|
||||
operationFunc, ok := operations[operation]
|
||||
if !ok {
|
||||
return fmt.Errorf("unsupported operation: %s", operation)
|
||||
}
|
||||
|
||||
if err := operationFunc(context.Background(), manifestFiles); err != nil {
|
||||
return fmt.Errorf("failed to execute kubectl %s command: %w", operation, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestExecuteKubectlOperation_Apply_Success(t *testing.T) {
|
||||
called := false
|
||||
mockClient := &mockKubectlClient{
|
||||
applyFunc: func(ctx context.Context, files []string) error {
|
||||
called = true
|
||||
assert.Equal(t, []string{"manifest1.yaml", "manifest2.yaml"}, files)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
manifests := []string{"manifest1.yaml", "manifest2.yaml"}
|
||||
err := testExecuteKubectlOperation(mockClient, "apply", manifests)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, called)
|
||||
}
|
||||
|
||||
func TestExecuteKubectlOperation_Apply_Error(t *testing.T) {
|
||||
expectedErr := errors.New("kubectl apply failed")
|
||||
called := false
|
||||
mockClient := &mockKubectlClient{
|
||||
applyFunc: func(ctx context.Context, files []string) error {
|
||||
called = true
|
||||
assert.Equal(t, []string{"error.yaml"}, files)
|
||||
return expectedErr
|
||||
},
|
||||
}
|
||||
|
||||
manifests := []string{"error.yaml"}
|
||||
err := testExecuteKubectlOperation(mockClient, "apply", manifests)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), expectedErr.Error())
|
||||
assert.True(t, called)
|
||||
}
|
||||
|
||||
func TestExecuteKubectlOperation_Delete_Success(t *testing.T) {
|
||||
called := false
|
||||
mockClient := &mockKubectlClient{
|
||||
deleteFunc: func(ctx context.Context, files []string) error {
|
||||
called = true
|
||||
assert.Equal(t, []string{"manifest1.yaml"}, files)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
manifests := []string{"manifest1.yaml"}
|
||||
err := testExecuteKubectlOperation(mockClient, "delete", manifests)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, called)
|
||||
}
|
||||
|
||||
func TestExecuteKubectlOperation_Delete_Error(t *testing.T) {
|
||||
expectedErr := errors.New("kubectl delete failed")
|
||||
called := false
|
||||
mockClient := &mockKubectlClient{
|
||||
deleteFunc: func(ctx context.Context, files []string) error {
|
||||
called = true
|
||||
assert.Equal(t, []string{"error.yaml"}, files)
|
||||
return expectedErr
|
||||
},
|
||||
}
|
||||
|
||||
manifests := []string{"error.yaml"}
|
||||
err := testExecuteKubectlOperation(mockClient, "delete", manifests)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), expectedErr.Error())
|
||||
assert.True(t, called)
|
||||
}
|
||||
|
||||
func TestExecuteKubectlOperation_RolloutRestart_Success(t *testing.T) {
|
||||
called := false
|
||||
mockClient := &mockKubectlClient{
|
||||
rolloutRestartFunc: func(ctx context.Context, resources []string) error {
|
||||
called = true
|
||||
assert.Equal(t, []string{"deployment/nginx"}, resources)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
resources := []string{"deployment/nginx"}
|
||||
err := testExecuteKubectlOperation(mockClient, "rollout-restart", resources)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, called)
|
||||
}
|
||||
|
||||
func TestExecuteKubectlOperation_RolloutRestart_Error(t *testing.T) {
|
||||
expectedErr := errors.New("kubectl rollout restart failed")
|
||||
called := false
|
||||
mockClient := &mockKubectlClient{
|
||||
rolloutRestartFunc: func(ctx context.Context, resources []string) error {
|
||||
called = true
|
||||
assert.Equal(t, []string{"deployment/error"}, resources)
|
||||
return expectedErr
|
||||
},
|
||||
}
|
||||
|
||||
resources := []string{"deployment/error"}
|
||||
err := testExecuteKubectlOperation(mockClient, "rollout-restart", resources)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), expectedErr.Error())
|
||||
assert.True(t, called)
|
||||
}
|
||||
|
||||
func TestExecuteKubectlOperation_UnsupportedOperation(t *testing.T) {
|
||||
mockClient := &mockKubectlClient{}
|
||||
|
||||
err := testExecuteKubectlOperation(mockClient, "unsupported", []string{})
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unsupported operation")
|
||||
}
|
||||
@@ -68,7 +68,7 @@ func copyFile(src, dst string) error {
|
||||
defer from.Close()
|
||||
|
||||
// has to include 'execute' bit, otherwise fails. MkdirAll follows `mkdir -m` restrictions
|
||||
if err := os.MkdirAll(filepath.Dir(dst), 0744); err != nil {
|
||||
if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
to, err := os.Create(dst)
|
||||
|
||||
@@ -848,7 +848,7 @@ func defaultMTLSCertPathUnderFileStore() (string, string, string) {
|
||||
return caCertPath, certPath, keyPath
|
||||
}
|
||||
|
||||
// GetDefaultChiselPrivateKeyPath returns the chisle private key path
|
||||
// GetDefaultChiselPrivateKeyPath returns the chisel private key path
|
||||
func (service *Service) GetDefaultChiselPrivateKeyPath() string {
|
||||
privateKeyPath := defaultChiselPrivateKeyPathUnderFileStore()
|
||||
return service.wrapFileStore(privateKeyPath)
|
||||
|
||||
@@ -15,15 +15,19 @@ type MultiFilterArgs []struct {
|
||||
}
|
||||
|
||||
// MultiFilterDirForPerDevConfigs filers the given dirEntries with multiple filter args, returns the merged entries for the given device
|
||||
func MultiFilterDirForPerDevConfigs(dirEntries []DirEntry, configPath string, multiFilterArgs MultiFilterArgs) []DirEntry {
|
||||
func MultiFilterDirForPerDevConfigs(dirEntries []DirEntry, configPath string, multiFilterArgs MultiFilterArgs) ([]DirEntry, []string) {
|
||||
var filteredDirEntries []DirEntry
|
||||
|
||||
var envFiles []string
|
||||
|
||||
for _, multiFilterArg := range multiFilterArgs {
|
||||
tmp := FilterDirForPerDevConfigs(dirEntries, multiFilterArg.FilterKey, configPath, multiFilterArg.FilterType)
|
||||
tmp, efs := FilterDirForPerDevConfigs(dirEntries, multiFilterArg.FilterKey, configPath, multiFilterArg.FilterType)
|
||||
filteredDirEntries = append(filteredDirEntries, tmp...)
|
||||
|
||||
envFiles = append(envFiles, efs...)
|
||||
}
|
||||
|
||||
return deduplicate(filteredDirEntries)
|
||||
return deduplicate(filteredDirEntries), envFiles
|
||||
}
|
||||
|
||||
func deduplicate(dirEntries []DirEntry) []DirEntry {
|
||||
@@ -32,8 +36,7 @@ func deduplicate(dirEntries []DirEntry) []DirEntry {
|
||||
marks := make(map[string]struct{})
|
||||
|
||||
for _, dirEntry := range dirEntries {
|
||||
_, ok := marks[dirEntry.Name]
|
||||
if !ok {
|
||||
if _, ok := marks[dirEntry.Name]; !ok {
|
||||
marks[dirEntry.Name] = struct{}{}
|
||||
deduplicatedDirEntries = append(deduplicatedDirEntries, dirEntry)
|
||||
}
|
||||
@@ -50,20 +53,25 @@ func deduplicate(dirEntries []DirEntry) []DirEntry {
|
||||
// 3. For filterType dir:
|
||||
// dir entry: A/B/C/<deviceName>
|
||||
// all entries: A/B/C/<deviceName>/*
|
||||
func FilterDirForPerDevConfigs(dirEntries []DirEntry, deviceName, configPath string, filterType portainer.PerDevConfigsFilterType) []DirEntry {
|
||||
func FilterDirForPerDevConfigs(dirEntries []DirEntry, deviceName, configPath string, filterType portainer.PerDevConfigsFilterType) ([]DirEntry, []string) {
|
||||
var filteredDirEntries []DirEntry
|
||||
|
||||
var envFiles []string
|
||||
|
||||
for _, dirEntry := range dirEntries {
|
||||
if shouldIncludeEntry(dirEntry, deviceName, configPath, filterType) {
|
||||
filteredDirEntries = append(filteredDirEntries, dirEntry)
|
||||
|
||||
if shouldParseEnvVars(dirEntry, deviceName, configPath, filterType) {
|
||||
envFiles = append(envFiles, dirEntry.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return filteredDirEntries
|
||||
return filteredDirEntries, envFiles
|
||||
}
|
||||
|
||||
func shouldIncludeEntry(dirEntry DirEntry, deviceName, configPath string, filterType portainer.PerDevConfigsFilterType) bool {
|
||||
|
||||
// Include all entries outside of dir A
|
||||
if !isInConfigDir(dirEntry, configPath) {
|
||||
return true
|
||||
@@ -120,6 +128,15 @@ func shouldIncludeDir(dirEntry DirEntry, deviceName, configPath string) bool {
|
||||
return strings.HasPrefix(dirEntry.Name, filterPrefix)
|
||||
}
|
||||
|
||||
func shouldParseEnvVars(dirEntry DirEntry, deviceName, configPath string, filterType portainer.PerDevConfigsFilterType) bool {
|
||||
if !dirEntry.IsFile {
|
||||
return false
|
||||
}
|
||||
|
||||
return isInConfigDir(dirEntry, configPath) &&
|
||||
filepath.Base(dirEntry.Name) == deviceName+".env"
|
||||
}
|
||||
|
||||
func appendTailSeparator(path string) string {
|
||||
return fmt.Sprintf("%s%c", path, os.PathSeparator)
|
||||
}
|
||||
|
||||
@@ -4,14 +4,17 @@ import (
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMultiFilterDirForPerDevConfigs(t *testing.T) {
|
||||
type args struct {
|
||||
dirEntries []DirEntry
|
||||
configPath string
|
||||
multiFilterArgs MultiFilterArgs
|
||||
f := func(dirEntries []DirEntry, configPath string, multiFilterArgs MultiFilterArgs, wantDirEntries []DirEntry) {
|
||||
t.Helper()
|
||||
|
||||
dirEntries, _ = MultiFilterDirForPerDevConfigs(dirEntries, configPath, multiFilterArgs)
|
||||
require.Equal(t, wantDirEntries, dirEntries)
|
||||
}
|
||||
|
||||
baseDirEntries := []DirEntry{
|
||||
@@ -26,69 +29,75 @@ func TestMultiFilterDirForPerDevConfigs(t *testing.T) {
|
||||
{"configs/folder2/config2", "", true, 420},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []DirEntry
|
||||
}{
|
||||
{
|
||||
name: "filter file1",
|
||||
args: args{
|
||||
baseDirEntries,
|
||||
"configs",
|
||||
MultiFilterArgs{{"file1", portainer.PerDevConfigsTypeFile}},
|
||||
},
|
||||
want: []DirEntry{baseDirEntries[0], baseDirEntries[1], baseDirEntries[2], baseDirEntries[3]},
|
||||
// Filter file1
|
||||
f(
|
||||
baseDirEntries,
|
||||
"configs",
|
||||
MultiFilterArgs{{"file1", portainer.PerDevConfigsTypeFile}},
|
||||
[]DirEntry{baseDirEntries[0], baseDirEntries[1], baseDirEntries[2], baseDirEntries[3]},
|
||||
)
|
||||
|
||||
// Filter folder1
|
||||
f(
|
||||
baseDirEntries,
|
||||
"configs",
|
||||
MultiFilterArgs{{"folder1", portainer.PerDevConfigsTypeDir}},
|
||||
[]DirEntry{baseDirEntries[0], baseDirEntries[1], baseDirEntries[2], baseDirEntries[5], baseDirEntries[6]},
|
||||
)
|
||||
|
||||
// Filter file1 and folder1
|
||||
f(
|
||||
baseDirEntries,
|
||||
"configs",
|
||||
MultiFilterArgs{{"folder1", portainer.PerDevConfigsTypeDir}},
|
||||
[]DirEntry{baseDirEntries[0], baseDirEntries[1], baseDirEntries[2], baseDirEntries[5], baseDirEntries[6]},
|
||||
)
|
||||
|
||||
// Filter file1 and file2
|
||||
f(
|
||||
baseDirEntries,
|
||||
"configs",
|
||||
MultiFilterArgs{
|
||||
{"file1", portainer.PerDevConfigsTypeFile},
|
||||
{"file2", portainer.PerDevConfigsTypeFile},
|
||||
},
|
||||
{
|
||||
name: "filter folder1",
|
||||
args: args{
|
||||
baseDirEntries,
|
||||
"configs",
|
||||
MultiFilterArgs{{"folder1", portainer.PerDevConfigsTypeDir}},
|
||||
},
|
||||
want: []DirEntry{baseDirEntries[0], baseDirEntries[1], baseDirEntries[2], baseDirEntries[5], baseDirEntries[6]},
|
||||
},
|
||||
{
|
||||
name: "filter file1 and folder1",
|
||||
args: args{
|
||||
baseDirEntries,
|
||||
"configs",
|
||||
MultiFilterArgs{{"folder1", portainer.PerDevConfigsTypeDir}},
|
||||
},
|
||||
want: []DirEntry{baseDirEntries[0], baseDirEntries[1], baseDirEntries[2], baseDirEntries[5], baseDirEntries[6]},
|
||||
},
|
||||
{
|
||||
name: "filter file1 and file2",
|
||||
args: args{
|
||||
baseDirEntries,
|
||||
"configs",
|
||||
MultiFilterArgs{
|
||||
{"file1", portainer.PerDevConfigsTypeFile},
|
||||
{"file2", portainer.PerDevConfigsTypeFile},
|
||||
},
|
||||
},
|
||||
want: []DirEntry{baseDirEntries[0], baseDirEntries[1], baseDirEntries[2], baseDirEntries[3], baseDirEntries[4]},
|
||||
},
|
||||
{
|
||||
name: "filter folder1 and folder2",
|
||||
args: args{
|
||||
baseDirEntries,
|
||||
"configs",
|
||||
MultiFilterArgs{
|
||||
{"folder1", portainer.PerDevConfigsTypeDir},
|
||||
{"folder2", portainer.PerDevConfigsTypeDir},
|
||||
},
|
||||
},
|
||||
want: []DirEntry{baseDirEntries[0], baseDirEntries[1], baseDirEntries[2], baseDirEntries[5], baseDirEntries[6], baseDirEntries[7], baseDirEntries[8]},
|
||||
[]DirEntry{baseDirEntries[0], baseDirEntries[1], baseDirEntries[2], baseDirEntries[3], baseDirEntries[4]},
|
||||
)
|
||||
|
||||
// Filter folder1 and folder2
|
||||
f(
|
||||
baseDirEntries,
|
||||
"configs",
|
||||
MultiFilterArgs{
|
||||
{"folder1", portainer.PerDevConfigsTypeDir},
|
||||
{"folder2", portainer.PerDevConfigsTypeDir},
|
||||
},
|
||||
[]DirEntry{baseDirEntries[0], baseDirEntries[1], baseDirEntries[2], baseDirEntries[5], baseDirEntries[6], baseDirEntries[7], baseDirEntries[8]},
|
||||
)
|
||||
}
|
||||
|
||||
func TestMultiFilterDirForPerDevConfigsEnvFiles(t *testing.T) {
|
||||
f := func(dirEntries []DirEntry, configPath string, multiFilterArgs MultiFilterArgs, wantEnvFiles []string) {
|
||||
t.Helper()
|
||||
|
||||
_, envFiles := MultiFilterDirForPerDevConfigs(dirEntries, configPath, multiFilterArgs)
|
||||
require.Equal(t, wantEnvFiles, envFiles)
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equalf(t, tt.want, MultiFilterDirForPerDevConfigs(tt.args.dirEntries, tt.args.configPath, tt.args.multiFilterArgs), "MultiFilterDirForPerDevConfigs(%v, %v, %v)", tt.args.dirEntries, tt.args.configPath, tt.args.multiFilterArgs)
|
||||
})
|
||||
baseDirEntries := []DirEntry{
|
||||
{".env", "", true, 420},
|
||||
{"docker-compose.yaml", "", true, 420},
|
||||
{"configs", "", false, 420},
|
||||
{"configs/edge-id/edge-id.env", "", true, 420},
|
||||
}
|
||||
|
||||
f(
|
||||
baseDirEntries,
|
||||
"configs",
|
||||
MultiFilterArgs{{"edge-id", portainer.PerDevConfigsTypeDir}},
|
||||
[]string{"configs/edge-id/edge-id.env"},
|
||||
)
|
||||
|
||||
}
|
||||
|
||||
func TestIsInConfigDir(t *testing.T) {
|
||||
|
||||
@@ -60,15 +60,9 @@ func NewAzureClient() *azureClient {
|
||||
}
|
||||
|
||||
func newHttpClientForAzure(insecureSkipVerify bool) *http.Client {
|
||||
tlsConfig := crypto.CreateTLSConfiguration()
|
||||
|
||||
if insecureSkipVerify {
|
||||
tlsConfig.InsecureSkipVerify = true
|
||||
}
|
||||
|
||||
httpsCli := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: tlsConfig,
|
||||
TLSClientConfig: crypto.CreateTLSConfiguration(insecureSkipVerify),
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
},
|
||||
Timeout: 300 * time.Second,
|
||||
|
||||
@@ -58,7 +58,15 @@ func TestService_ClonePublicRepository_Azure(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
dst := t.TempDir()
|
||||
repositoryUrl := fmt.Sprintf(tt.args.repositoryURLFormat, tt.args.password)
|
||||
err := service.CloneRepository(dst, repositoryUrl, tt.args.referenceName, "", "", false)
|
||||
err := service.CloneRepository(
|
||||
dst,
|
||||
repositoryUrl,
|
||||
tt.args.referenceName,
|
||||
"",
|
||||
"",
|
||||
gittypes.GitCredentialAuthType_Basic,
|
||||
false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.FileExists(t, filepath.Join(dst, "README.md"))
|
||||
})
|
||||
@@ -73,7 +81,15 @@ func TestService_ClonePrivateRepository_Azure(t *testing.T) {
|
||||
|
||||
dst := t.TempDir()
|
||||
|
||||
err := service.CloneRepository(dst, privateAzureRepoURL, "refs/heads/main", "", pat, false)
|
||||
err := service.CloneRepository(
|
||||
dst,
|
||||
privateAzureRepoURL,
|
||||
"refs/heads/main",
|
||||
"",
|
||||
pat,
|
||||
gittypes.GitCredentialAuthType_Basic,
|
||||
false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.FileExists(t, filepath.Join(dst, "README.md"))
|
||||
}
|
||||
@@ -84,7 +100,14 @@ func TestService_LatestCommitID_Azure(t *testing.T) {
|
||||
pat := getRequiredValue(t, "AZURE_DEVOPS_PAT")
|
||||
service := NewService(context.TODO())
|
||||
|
||||
id, err := service.LatestCommitID(privateAzureRepoURL, "refs/heads/main", "", pat, false)
|
||||
id, err := service.LatestCommitID(
|
||||
privateAzureRepoURL,
|
||||
"refs/heads/main",
|
||||
"",
|
||||
pat,
|
||||
gittypes.GitCredentialAuthType_Basic,
|
||||
false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, id, "cannot guarantee commit id, but it should be not empty")
|
||||
}
|
||||
@@ -96,7 +119,14 @@ func TestService_ListRefs_Azure(t *testing.T) {
|
||||
username := getRequiredValue(t, "AZURE_DEVOPS_USERNAME")
|
||||
service := NewService(context.TODO())
|
||||
|
||||
refs, err := service.ListRefs(privateAzureRepoURL, username, accessToken, false, false)
|
||||
refs, err := service.ListRefs(
|
||||
privateAzureRepoURL,
|
||||
username,
|
||||
accessToken,
|
||||
gittypes.GitCredentialAuthType_Basic,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(refs), 1)
|
||||
}
|
||||
@@ -108,8 +138,8 @@ func TestService_ListRefs_Azure_Concurrently(t *testing.T) {
|
||||
username := getRequiredValue(t, "AZURE_DEVOPS_USERNAME")
|
||||
service := newService(context.TODO(), repositoryCacheSize, 200*time.Millisecond)
|
||||
|
||||
go service.ListRefs(privateAzureRepoURL, username, accessToken, false, false)
|
||||
service.ListRefs(privateAzureRepoURL, username, accessToken, false, false)
|
||||
go service.ListRefs(privateAzureRepoURL, username, accessToken, gittypes.GitCredentialAuthType_Basic, false, false)
|
||||
service.ListRefs(privateAzureRepoURL, username, accessToken, gittypes.GitCredentialAuthType_Basic, false, false)
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
@@ -247,7 +277,17 @@ func TestService_ListFiles_Azure(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
paths, err := service.ListFiles(tt.args.repositoryUrl, tt.args.referenceName, tt.args.username, tt.args.password, false, false, tt.extensions, false)
|
||||
paths, err := service.ListFiles(
|
||||
tt.args.repositoryUrl,
|
||||
tt.args.referenceName,
|
||||
tt.args.username,
|
||||
tt.args.password,
|
||||
gittypes.GitCredentialAuthType_Basic,
|
||||
false,
|
||||
false,
|
||||
tt.extensions,
|
||||
false,
|
||||
)
|
||||
if tt.expect.shouldFail {
|
||||
assert.Error(t, err)
|
||||
if tt.expect.err != nil {
|
||||
@@ -270,8 +310,28 @@ func TestService_ListFiles_Azure_Concurrently(t *testing.T) {
|
||||
username := getRequiredValue(t, "AZURE_DEVOPS_USERNAME")
|
||||
service := newService(context.TODO(), repositoryCacheSize, 200*time.Millisecond)
|
||||
|
||||
go service.ListFiles(privateAzureRepoURL, "refs/heads/main", username, accessToken, false, false, []string{}, false)
|
||||
service.ListFiles(privateAzureRepoURL, "refs/heads/main", username, accessToken, false, false, []string{}, false)
|
||||
go service.ListFiles(
|
||||
privateAzureRepoURL,
|
||||
"refs/heads/main",
|
||||
username,
|
||||
accessToken,
|
||||
gittypes.GitCredentialAuthType_Basic,
|
||||
false,
|
||||
false,
|
||||
[]string{},
|
||||
false,
|
||||
)
|
||||
service.ListFiles(
|
||||
privateAzureRepoURL,
|
||||
"refs/heads/main",
|
||||
username,
|
||||
accessToken,
|
||||
gittypes.GitCredentialAuthType_Basic,
|
||||
false,
|
||||
false,
|
||||
[]string{},
|
||||
false,
|
||||
)
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"testing"
|
||||
|
||||
gittypes "github.com/portainer/portainer/api/git/types"
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@@ -234,6 +235,8 @@ func Test_isAzureUrl(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_azureDownloader_downloadZipFromAzureDevOps(t *testing.T) {
|
||||
fips.InitFIPS(false)
|
||||
|
||||
type args struct {
|
||||
options baseOption
|
||||
}
|
||||
@@ -308,6 +311,8 @@ func Test_azureDownloader_downloadZipFromAzureDevOps(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_azureDownloader_latestCommitID(t *testing.T) {
|
||||
fips.InitFIPS(false)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
response := `{
|
||||
"count": 1,
|
||||
|
||||
@@ -19,6 +19,7 @@ type CloneOptions struct {
|
||||
ReferenceName string
|
||||
Username string
|
||||
Password string
|
||||
AuthType gittypes.GitCredentialAuthType
|
||||
// TLSSkipVerify skips SSL verification when cloning the Git repository
|
||||
TLSSkipVerify bool `example:"false"`
|
||||
}
|
||||
@@ -42,7 +43,15 @@ func CloneWithBackup(gitService portainer.GitService, fileService portainer.File
|
||||
|
||||
cleanUp = true
|
||||
|
||||
if err := gitService.CloneRepository(options.ProjectPath, options.URL, options.ReferenceName, options.Username, options.Password, options.TLSSkipVerify); err != nil {
|
||||
if err := gitService.CloneRepository(
|
||||
options.ProjectPath,
|
||||
options.URL,
|
||||
options.ReferenceName,
|
||||
options.Username,
|
||||
options.Password,
|
||||
options.AuthType,
|
||||
options.TLSSkipVerify,
|
||||
); err != nil {
|
||||
cleanUp = false
|
||||
if err := filesystem.MoveDirectory(backupProjectPath, options.ProjectPath, false); err != nil {
|
||||
log.Warn().Err(err).Msg("failed restoring backup folder")
|
||||
|
||||
@@ -7,12 +7,14 @@ import (
|
||||
"strings"
|
||||
|
||||
gittypes "github.com/portainer/portainer/api/git/types"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/go-git/go-git/v5"
|
||||
"github.com/go-git/go-git/v5/config"
|
||||
"github.com/go-git/go-git/v5/plumbing"
|
||||
"github.com/go-git/go-git/v5/plumbing/filemode"
|
||||
"github.com/go-git/go-git/v5/plumbing/object"
|
||||
"github.com/go-git/go-git/v5/plumbing/transport"
|
||||
githttp "github.com/go-git/go-git/v5/plumbing/transport/http"
|
||||
"github.com/go-git/go-git/v5/storage/memory"
|
||||
"github.com/pkg/errors"
|
||||
@@ -33,7 +35,7 @@ func (c *gitClient) download(ctx context.Context, dst string, opt cloneOption) e
|
||||
URL: opt.repositoryUrl,
|
||||
Depth: opt.depth,
|
||||
InsecureSkipTLS: opt.tlsSkipVerify,
|
||||
Auth: getAuth(opt.username, opt.password),
|
||||
Auth: getAuth(opt.authType, opt.username, opt.password),
|
||||
Tags: git.NoTags,
|
||||
}
|
||||
|
||||
@@ -51,7 +53,10 @@ func (c *gitClient) download(ctx context.Context, dst string, opt cloneOption) e
|
||||
}
|
||||
|
||||
if !c.preserveGitDirectory {
|
||||
os.RemoveAll(filepath.Join(dst, ".git"))
|
||||
err := os.RemoveAll(filepath.Join(dst, ".git"))
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to remove .git directory")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -64,7 +69,7 @@ func (c *gitClient) latestCommitID(ctx context.Context, opt fetchOption) (string
|
||||
})
|
||||
|
||||
listOptions := &git.ListOptions{
|
||||
Auth: getAuth(opt.username, opt.password),
|
||||
Auth: getAuth(opt.authType, opt.username, opt.password),
|
||||
InsecureSkipTLS: opt.tlsSkipVerify,
|
||||
}
|
||||
|
||||
@@ -94,7 +99,23 @@ func (c *gitClient) latestCommitID(ctx context.Context, opt fetchOption) (string
|
||||
return "", errors.Errorf("could not find ref %q in the repository", opt.referenceName)
|
||||
}
|
||||
|
||||
func getAuth(username, password string) *githttp.BasicAuth {
|
||||
func getAuth(authType gittypes.GitCredentialAuthType, username, password string) transport.AuthMethod {
|
||||
if password == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch authType {
|
||||
case gittypes.GitCredentialAuthType_Basic:
|
||||
return getBasicAuth(username, password)
|
||||
case gittypes.GitCredentialAuthType_Token:
|
||||
return getTokenAuth(password)
|
||||
default:
|
||||
log.Warn().Msg("unknown git credentials authorization type, defaulting to None")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func getBasicAuth(username, password string) *githttp.BasicAuth {
|
||||
if password != "" {
|
||||
if username == "" {
|
||||
username = "token"
|
||||
@@ -108,6 +129,15 @@ func getAuth(username, password string) *githttp.BasicAuth {
|
||||
return nil
|
||||
}
|
||||
|
||||
func getTokenAuth(token string) *githttp.TokenAuth {
|
||||
if token != "" {
|
||||
return &githttp.TokenAuth{
|
||||
Token: token,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *gitClient) listRefs(ctx context.Context, opt baseOption) ([]string, error) {
|
||||
rem := git.NewRemote(memory.NewStorage(), &config.RemoteConfig{
|
||||
Name: "origin",
|
||||
@@ -115,7 +145,7 @@ func (c *gitClient) listRefs(ctx context.Context, opt baseOption) ([]string, err
|
||||
})
|
||||
|
||||
listOptions := &git.ListOptions{
|
||||
Auth: getAuth(opt.username, opt.password),
|
||||
Auth: getAuth(opt.authType, opt.username, opt.password),
|
||||
InsecureSkipTLS: opt.tlsSkipVerify,
|
||||
}
|
||||
|
||||
@@ -143,7 +173,7 @@ func (c *gitClient) listFiles(ctx context.Context, opt fetchOption) ([]string, e
|
||||
Depth: 1,
|
||||
SingleBranch: true,
|
||||
ReferenceName: plumbing.ReferenceName(opt.referenceName),
|
||||
Auth: getAuth(opt.username, opt.password),
|
||||
Auth: getAuth(opt.authType, opt.username, opt.password),
|
||||
InsecureSkipTLS: opt.tlsSkipVerify,
|
||||
Tags: git.NoTags,
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package git
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -24,7 +26,15 @@ func TestService_ClonePrivateRepository_GitHub(t *testing.T) {
|
||||
dst := t.TempDir()
|
||||
|
||||
repositoryUrl := privateGitRepoURL
|
||||
err := service.CloneRepository(dst, repositoryUrl, "refs/heads/main", username, accessToken, false)
|
||||
err := service.CloneRepository(
|
||||
dst,
|
||||
repositoryUrl,
|
||||
"refs/heads/main",
|
||||
username,
|
||||
accessToken,
|
||||
gittypes.GitCredentialAuthType_Basic,
|
||||
false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.FileExists(t, filepath.Join(dst, "README.md"))
|
||||
}
|
||||
@@ -37,7 +47,14 @@ func TestService_LatestCommitID_GitHub(t *testing.T) {
|
||||
service := newService(context.TODO(), 0, 0)
|
||||
|
||||
repositoryUrl := privateGitRepoURL
|
||||
id, err := service.LatestCommitID(repositoryUrl, "refs/heads/main", username, accessToken, false)
|
||||
id, err := service.LatestCommitID(
|
||||
repositoryUrl,
|
||||
"refs/heads/main",
|
||||
username,
|
||||
accessToken,
|
||||
gittypes.GitCredentialAuthType_Basic,
|
||||
false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, id, "cannot guarantee commit id, but it should be not empty")
|
||||
}
|
||||
@@ -50,7 +67,7 @@ func TestService_ListRefs_GitHub(t *testing.T) {
|
||||
service := newService(context.TODO(), 0, 0)
|
||||
|
||||
repositoryUrl := privateGitRepoURL
|
||||
refs, err := service.ListRefs(repositoryUrl, username, accessToken, false, false)
|
||||
refs, err := service.ListRefs(repositoryUrl, username, accessToken, gittypes.GitCredentialAuthType_Basic, false, false)
|
||||
assert.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(refs), 1)
|
||||
}
|
||||
@@ -63,8 +80,8 @@ func TestService_ListRefs_Github_Concurrently(t *testing.T) {
|
||||
service := newService(context.TODO(), repositoryCacheSize, 200*time.Millisecond)
|
||||
|
||||
repositoryUrl := privateGitRepoURL
|
||||
go service.ListRefs(repositoryUrl, username, accessToken, false, false)
|
||||
service.ListRefs(repositoryUrl, username, accessToken, false, false)
|
||||
go service.ListRefs(repositoryUrl, username, accessToken, gittypes.GitCredentialAuthType_Basic, false, false)
|
||||
service.ListRefs(repositoryUrl, username, accessToken, gittypes.GitCredentialAuthType_Basic, false, false)
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
@@ -202,7 +219,17 @@ func TestService_ListFiles_GitHub(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
paths, err := service.ListFiles(tt.args.repositoryUrl, tt.args.referenceName, tt.args.username, tt.args.password, false, false, tt.extensions, false)
|
||||
paths, err := service.ListFiles(
|
||||
tt.args.repositoryUrl,
|
||||
tt.args.referenceName,
|
||||
tt.args.username,
|
||||
tt.args.password,
|
||||
gittypes.GitCredentialAuthType_Basic,
|
||||
false,
|
||||
false,
|
||||
tt.extensions,
|
||||
false,
|
||||
)
|
||||
if tt.expect.shouldFail {
|
||||
assert.Error(t, err)
|
||||
if tt.expect.err != nil {
|
||||
@@ -226,8 +253,28 @@ func TestService_ListFiles_Github_Concurrently(t *testing.T) {
|
||||
username := getRequiredValue(t, "GITHUB_USERNAME")
|
||||
service := newService(context.TODO(), repositoryCacheSize, 200*time.Millisecond)
|
||||
|
||||
go service.ListFiles(repositoryUrl, "refs/heads/main", username, accessToken, false, false, []string{}, false)
|
||||
service.ListFiles(repositoryUrl, "refs/heads/main", username, accessToken, false, false, []string{}, false)
|
||||
go service.ListFiles(
|
||||
repositoryUrl,
|
||||
"refs/heads/main",
|
||||
username,
|
||||
accessToken,
|
||||
gittypes.GitCredentialAuthType_Basic,
|
||||
false,
|
||||
false,
|
||||
[]string{},
|
||||
false,
|
||||
)
|
||||
service.ListFiles(
|
||||
repositoryUrl,
|
||||
"refs/heads/main",
|
||||
username,
|
||||
accessToken,
|
||||
gittypes.GitCredentialAuthType_Basic,
|
||||
false,
|
||||
false,
|
||||
[]string{},
|
||||
false,
|
||||
)
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
@@ -240,8 +287,18 @@ func TestService_purgeCache_Github(t *testing.T) {
|
||||
username := getRequiredValue(t, "GITHUB_USERNAME")
|
||||
service := NewService(context.TODO())
|
||||
|
||||
service.ListRefs(repositoryUrl, username, accessToken, false, false)
|
||||
service.ListFiles(repositoryUrl, "refs/heads/main", username, accessToken, false, false, []string{}, false)
|
||||
service.ListRefs(repositoryUrl, username, accessToken, gittypes.GitCredentialAuthType_Basic, false, false)
|
||||
service.ListFiles(
|
||||
repositoryUrl,
|
||||
"refs/heads/main",
|
||||
username,
|
||||
accessToken,
|
||||
gittypes.GitCredentialAuthType_Basic,
|
||||
false,
|
||||
false,
|
||||
[]string{},
|
||||
false,
|
||||
)
|
||||
|
||||
assert.Equal(t, 1, service.repoRefCache.Len())
|
||||
assert.Equal(t, 1, service.repoFileCache.Len())
|
||||
@@ -261,8 +318,18 @@ func TestService_purgeCacheByTTL_Github(t *testing.T) {
|
||||
// 40*timeout is designed for giving enough time for ListRefs and ListFiles to cache the result
|
||||
service := newService(context.TODO(), 2, 40*timeout)
|
||||
|
||||
service.ListRefs(repositoryUrl, username, accessToken, false, false)
|
||||
service.ListFiles(repositoryUrl, "refs/heads/main", username, accessToken, false, false, []string{}, false)
|
||||
service.ListRefs(repositoryUrl, username, accessToken, gittypes.GitCredentialAuthType_Basic, false, false)
|
||||
service.ListFiles(
|
||||
repositoryUrl,
|
||||
"refs/heads/main",
|
||||
username,
|
||||
accessToken,
|
||||
gittypes.GitCredentialAuthType_Basic,
|
||||
false,
|
||||
false,
|
||||
[]string{},
|
||||
false,
|
||||
)
|
||||
assert.Equal(t, 1, service.repoRefCache.Len())
|
||||
assert.Equal(t, 1, service.repoFileCache.Len())
|
||||
|
||||
@@ -293,12 +360,12 @@ func TestService_HardRefresh_ListRefs_GitHub(t *testing.T) {
|
||||
service := newService(context.TODO(), 2, 0)
|
||||
|
||||
repositoryUrl := privateGitRepoURL
|
||||
refs, err := service.ListRefs(repositoryUrl, username, accessToken, false, false)
|
||||
refs, err := service.ListRefs(repositoryUrl, username, accessToken, gittypes.GitCredentialAuthType_Basic, false, false)
|
||||
assert.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(refs), 1)
|
||||
assert.Equal(t, 1, service.repoRefCache.Len())
|
||||
|
||||
_, err = service.ListRefs(repositoryUrl, username, "fake-token", false, false)
|
||||
_, err = service.ListRefs(repositoryUrl, username, "fake-token", gittypes.GitCredentialAuthType_Basic, false, false)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, 1, service.repoRefCache.Len())
|
||||
}
|
||||
@@ -311,26 +378,46 @@ func TestService_HardRefresh_ListRefs_And_RemoveAllCaches_GitHub(t *testing.T) {
|
||||
service := newService(context.TODO(), 2, 0)
|
||||
|
||||
repositoryUrl := privateGitRepoURL
|
||||
refs, err := service.ListRefs(repositoryUrl, username, accessToken, false, false)
|
||||
refs, err := service.ListRefs(repositoryUrl, username, accessToken, gittypes.GitCredentialAuthType_Basic, false, false)
|
||||
assert.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(refs), 1)
|
||||
assert.Equal(t, 1, service.repoRefCache.Len())
|
||||
|
||||
files, err := service.ListFiles(repositoryUrl, "refs/heads/main", username, accessToken, false, false, []string{}, false)
|
||||
files, err := service.ListFiles(
|
||||
repositoryUrl,
|
||||
"refs/heads/main",
|
||||
username,
|
||||
accessToken,
|
||||
gittypes.GitCredentialAuthType_Basic,
|
||||
false,
|
||||
false,
|
||||
[]string{},
|
||||
false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(files), 1)
|
||||
assert.Equal(t, 1, service.repoFileCache.Len())
|
||||
|
||||
files, err = service.ListFiles(repositoryUrl, "refs/heads/test", username, accessToken, false, false, []string{}, false)
|
||||
files, err = service.ListFiles(
|
||||
repositoryUrl,
|
||||
"refs/heads/test",
|
||||
username,
|
||||
accessToken,
|
||||
gittypes.GitCredentialAuthType_Basic,
|
||||
false,
|
||||
false,
|
||||
[]string{},
|
||||
false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(files), 1)
|
||||
assert.Equal(t, 2, service.repoFileCache.Len())
|
||||
|
||||
_, err = service.ListRefs(repositoryUrl, username, "fake-token", false, false)
|
||||
_, err = service.ListRefs(repositoryUrl, username, "fake-token", gittypes.GitCredentialAuthType_Basic, false, false)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, 1, service.repoRefCache.Len())
|
||||
|
||||
_, err = service.ListRefs(repositoryUrl, username, "fake-token", true, false)
|
||||
_, err = service.ListRefs(repositoryUrl, username, "fake-token", gittypes.GitCredentialAuthType_Basic, true, false)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, 1, service.repoRefCache.Len())
|
||||
// The relevant file caches should be removed too
|
||||
@@ -344,12 +431,72 @@ func TestService_HardRefresh_ListFiles_GitHub(t *testing.T) {
|
||||
accessToken := getRequiredValue(t, "GITHUB_PAT")
|
||||
username := getRequiredValue(t, "GITHUB_USERNAME")
|
||||
repositoryUrl := privateGitRepoURL
|
||||
files, err := service.ListFiles(repositoryUrl, "refs/heads/main", username, accessToken, false, false, []string{}, false)
|
||||
files, err := service.ListFiles(
|
||||
repositoryUrl,
|
||||
"refs/heads/main",
|
||||
username,
|
||||
accessToken,
|
||||
gittypes.GitCredentialAuthType_Basic,
|
||||
false,
|
||||
false,
|
||||
[]string{},
|
||||
false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(files), 1)
|
||||
assert.Equal(t, 1, service.repoFileCache.Len())
|
||||
|
||||
_, err = service.ListFiles(repositoryUrl, "refs/heads/main", username, "fake-token", false, true, []string{}, false)
|
||||
_, err = service.ListFiles(
|
||||
repositoryUrl,
|
||||
"refs/heads/main",
|
||||
username,
|
||||
"fake-token",
|
||||
gittypes.GitCredentialAuthType_Basic,
|
||||
false,
|
||||
true,
|
||||
[]string{},
|
||||
false,
|
||||
)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, 0, service.repoFileCache.Len())
|
||||
}
|
||||
|
||||
func TestService_CloneRepository_TokenAuth(t *testing.T) {
|
||||
ensureIntegrationTest(t)
|
||||
|
||||
service := newService(context.TODO(), 2, 0)
|
||||
var requests []*http.Request
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requests = append(requests, r)
|
||||
}))
|
||||
accessToken := "test_access_token"
|
||||
username := "test_username"
|
||||
repositoryUrl := testServer.URL
|
||||
|
||||
// Since we aren't hitting a real git server we ignore the error
|
||||
_ = service.CloneRepository(
|
||||
"test_dir",
|
||||
repositoryUrl,
|
||||
"refs/heads/main",
|
||||
username,
|
||||
accessToken,
|
||||
gittypes.GitCredentialAuthType_Token,
|
||||
false,
|
||||
)
|
||||
|
||||
testServer.Close()
|
||||
|
||||
if len(requests) != 1 {
|
||||
t.Fatalf("expected 1 request sent but got %d", len(requests))
|
||||
}
|
||||
|
||||
gotAuthHeader := requests[0].Header.Get("Authorization")
|
||||
if gotAuthHeader == "" {
|
||||
t.Fatal("no Authorization header in git request")
|
||||
}
|
||||
|
||||
expectedAuthHeader := "Bearer test_access_token"
|
||||
if gotAuthHeader != expectedAuthHeader {
|
||||
t.Fatalf("expected Authorization header %q but got %q", expectedAuthHeader, gotAuthHeader)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,9 +38,9 @@ func Test_ClonePublicRepository_Shallow(t *testing.T) {
|
||||
|
||||
dir := t.TempDir()
|
||||
t.Logf("Cloning into %s", dir)
|
||||
err := service.CloneRepository(dir, repositoryURL, referenceName, "", "", false)
|
||||
err := service.CloneRepository(dir, repositoryURL, referenceName, "", "", gittypes.GitCredentialAuthType_Basic, false)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, getCommitHistoryLength(t, err, dir), "cloned repo has incorrect depth")
|
||||
assert.Equal(t, 1, getCommitHistoryLength(t, dir), "cloned repo has incorrect depth")
|
||||
}
|
||||
|
||||
func Test_ClonePublicRepository_NoGitDirectory(t *testing.T) {
|
||||
@@ -50,7 +50,7 @@ func Test_ClonePublicRepository_NoGitDirectory(t *testing.T) {
|
||||
|
||||
dir := t.TempDir()
|
||||
t.Logf("Cloning into %s", dir)
|
||||
err := service.CloneRepository(dir, repositoryURL, referenceName, "", "", false)
|
||||
err := service.CloneRepository(dir, repositoryURL, referenceName, "", "", gittypes.GitCredentialAuthType_Basic, false)
|
||||
assert.NoError(t, err)
|
||||
assert.NoDirExists(t, filepath.Join(dir, ".git"))
|
||||
}
|
||||
@@ -75,7 +75,7 @@ func Test_cloneRepository(t *testing.T) {
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 4, getCommitHistoryLength(t, err, dir), "cloned repo has incorrect depth")
|
||||
assert.Equal(t, 4, getCommitHistoryLength(t, dir), "cloned repo has incorrect depth")
|
||||
}
|
||||
|
||||
func Test_latestCommitID(t *testing.T) {
|
||||
@@ -84,7 +84,7 @@ func Test_latestCommitID(t *testing.T) {
|
||||
repositoryURL := setup(t)
|
||||
referenceName := "refs/heads/main"
|
||||
|
||||
id, err := service.LatestCommitID(repositoryURL, referenceName, "", "", false)
|
||||
id, err := service.LatestCommitID(repositoryURL, referenceName, "", "", gittypes.GitCredentialAuthType_Basic, false)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "68dcaa7bd452494043c64252ab90db0f98ecf8d2", id)
|
||||
@@ -95,7 +95,7 @@ func Test_ListRefs(t *testing.T) {
|
||||
|
||||
repositoryURL := setup(t)
|
||||
|
||||
fs, err := service.ListRefs(repositoryURL, "", "", false, false)
|
||||
fs, err := service.ListRefs(repositoryURL, "", "", gittypes.GitCredentialAuthType_Basic, false, false)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"refs/heads/main"}, fs)
|
||||
@@ -107,13 +107,23 @@ func Test_ListFiles(t *testing.T) {
|
||||
repositoryURL := setup(t)
|
||||
referenceName := "refs/heads/main"
|
||||
|
||||
fs, err := service.ListFiles(repositoryURL, referenceName, "", "", false, false, []string{".yml"}, false)
|
||||
fs, err := service.ListFiles(
|
||||
repositoryURL,
|
||||
referenceName,
|
||||
"",
|
||||
"",
|
||||
gittypes.GitCredentialAuthType_Basic,
|
||||
false,
|
||||
false,
|
||||
[]string{".yml"},
|
||||
false,
|
||||
)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"docker-compose.yml"}, fs)
|
||||
}
|
||||
|
||||
func getCommitHistoryLength(t *testing.T, err error, dir string) int {
|
||||
func getCommitHistoryLength(t *testing.T, dir string) int {
|
||||
repo, err := git.PlainOpen(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("can't open a git repo at %s with error %v", dir, err)
|
||||
@@ -125,11 +135,10 @@ func getCommitHistoryLength(t *testing.T, err error, dir string) int {
|
||||
}
|
||||
|
||||
count := 0
|
||||
err = iter.ForEach(func(_ *object.Commit) error {
|
||||
if err := iter.ForEach(func(_ *object.Commit) error {
|
||||
count++
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
}); err != nil {
|
||||
t.Fatalf("can't iterate over the commit history with error %v", err)
|
||||
}
|
||||
|
||||
@@ -255,7 +264,7 @@ func Test_listFilesPrivateRepository(t *testing.T) {
|
||||
name: "list tree with real repository and head ref but no credential",
|
||||
args: fetchOption{
|
||||
baseOption: baseOption{
|
||||
repositoryUrl: privateGitRepoURL + "fake",
|
||||
repositoryUrl: privateGitRepoURL,
|
||||
username: "",
|
||||
password: "",
|
||||
},
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
lru "github.com/hashicorp/golang-lru"
|
||||
gittypes "github.com/portainer/portainer/api/git/types"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
@@ -22,6 +23,7 @@ type baseOption struct {
|
||||
repositoryUrl string
|
||||
username string
|
||||
password string
|
||||
authType gittypes.GitCredentialAuthType
|
||||
tlsSkipVerify bool
|
||||
}
|
||||
|
||||
@@ -123,13 +125,22 @@ func (service *Service) timerHasStopped() bool {
|
||||
|
||||
// CloneRepository clones a git repository using the specified URL in the specified
|
||||
// destination folder.
|
||||
func (service *Service) CloneRepository(destination, repositoryURL, referenceName, username, password string, tlsSkipVerify bool) error {
|
||||
func (service *Service) CloneRepository(
|
||||
destination,
|
||||
repositoryURL,
|
||||
referenceName,
|
||||
username,
|
||||
password string,
|
||||
authType gittypes.GitCredentialAuthType,
|
||||
tlsSkipVerify bool,
|
||||
) error {
|
||||
options := cloneOption{
|
||||
fetchOption: fetchOption{
|
||||
baseOption: baseOption{
|
||||
repositoryUrl: repositoryURL,
|
||||
username: username,
|
||||
password: password,
|
||||
authType: authType,
|
||||
tlsSkipVerify: tlsSkipVerify,
|
||||
},
|
||||
referenceName: referenceName,
|
||||
@@ -155,12 +166,20 @@ func (service *Service) cloneRepository(destination string, options cloneOption)
|
||||
}
|
||||
|
||||
// LatestCommitID returns SHA1 of the latest commit of the specified reference
|
||||
func (service *Service) LatestCommitID(repositoryURL, referenceName, username, password string, tlsSkipVerify bool) (string, error) {
|
||||
func (service *Service) LatestCommitID(
|
||||
repositoryURL,
|
||||
referenceName,
|
||||
username,
|
||||
password string,
|
||||
authType gittypes.GitCredentialAuthType,
|
||||
tlsSkipVerify bool,
|
||||
) (string, error) {
|
||||
options := fetchOption{
|
||||
baseOption: baseOption{
|
||||
repositoryUrl: repositoryURL,
|
||||
username: username,
|
||||
password: password,
|
||||
authType: authType,
|
||||
tlsSkipVerify: tlsSkipVerify,
|
||||
},
|
||||
referenceName: referenceName,
|
||||
@@ -170,7 +189,14 @@ func (service *Service) LatestCommitID(repositoryURL, referenceName, username, p
|
||||
}
|
||||
|
||||
// ListRefs will list target repository's references without cloning the repository
|
||||
func (service *Service) ListRefs(repositoryURL, username, password string, hardRefresh bool, tlsSkipVerify bool) ([]string, error) {
|
||||
func (service *Service) ListRefs(
|
||||
repositoryURL,
|
||||
username,
|
||||
password string,
|
||||
authType gittypes.GitCredentialAuthType,
|
||||
hardRefresh bool,
|
||||
tlsSkipVerify bool,
|
||||
) ([]string, error) {
|
||||
refCacheKey := generateCacheKey(repositoryURL, username, password, strconv.FormatBool(tlsSkipVerify))
|
||||
if service.cacheEnabled && hardRefresh {
|
||||
// Should remove the cache explicitly, so that the following normal list can show the correct result
|
||||
@@ -196,6 +222,7 @@ func (service *Service) ListRefs(repositoryURL, username, password string, hardR
|
||||
repositoryUrl: repositoryURL,
|
||||
username: username,
|
||||
password: password,
|
||||
authType: authType,
|
||||
tlsSkipVerify: tlsSkipVerify,
|
||||
}
|
||||
|
||||
@@ -215,18 +242,62 @@ var singleflightGroup = &singleflight.Group{}
|
||||
|
||||
// ListFiles will list all the files of the target repository with specific extensions.
|
||||
// If extension is not provided, it will list all the files under the target repository
|
||||
func (service *Service) ListFiles(repositoryURL, referenceName, username, password string, dirOnly, hardRefresh bool, includedExts []string, tlsSkipVerify bool) ([]string, error) {
|
||||
repoKey := generateCacheKey(repositoryURL, referenceName, username, password, strconv.FormatBool(tlsSkipVerify), strconv.FormatBool(dirOnly))
|
||||
func (service *Service) ListFiles(
|
||||
repositoryURL,
|
||||
referenceName,
|
||||
username,
|
||||
password string,
|
||||
authType gittypes.GitCredentialAuthType,
|
||||
dirOnly,
|
||||
hardRefresh bool,
|
||||
includedExts []string,
|
||||
tlsSkipVerify bool,
|
||||
) ([]string, error) {
|
||||
repoKey := generateCacheKey(
|
||||
repositoryURL,
|
||||
referenceName,
|
||||
username,
|
||||
password,
|
||||
strconv.FormatBool(tlsSkipVerify),
|
||||
strconv.Itoa(int(authType)),
|
||||
strconv.FormatBool(dirOnly),
|
||||
)
|
||||
|
||||
fs, err, _ := singleflightGroup.Do(repoKey, func() (any, error) {
|
||||
return service.listFiles(repositoryURL, referenceName, username, password, dirOnly, hardRefresh, tlsSkipVerify)
|
||||
return service.listFiles(
|
||||
repositoryURL,
|
||||
referenceName,
|
||||
username,
|
||||
password,
|
||||
authType,
|
||||
dirOnly,
|
||||
hardRefresh,
|
||||
tlsSkipVerify,
|
||||
)
|
||||
})
|
||||
|
||||
return filterFiles(fs.([]string), includedExts), err
|
||||
}
|
||||
|
||||
func (service *Service) listFiles(repositoryURL, referenceName, username, password string, dirOnly, hardRefresh bool, tlsSkipVerify bool) ([]string, error) {
|
||||
repoKey := generateCacheKey(repositoryURL, referenceName, username, password, strconv.FormatBool(tlsSkipVerify), strconv.FormatBool(dirOnly))
|
||||
func (service *Service) listFiles(
|
||||
repositoryURL,
|
||||
referenceName,
|
||||
username,
|
||||
password string,
|
||||
authType gittypes.GitCredentialAuthType,
|
||||
dirOnly,
|
||||
hardRefresh bool,
|
||||
tlsSkipVerify bool,
|
||||
) ([]string, error) {
|
||||
repoKey := generateCacheKey(
|
||||
repositoryURL,
|
||||
referenceName,
|
||||
username,
|
||||
password,
|
||||
strconv.FormatBool(tlsSkipVerify),
|
||||
strconv.Itoa(int(authType)),
|
||||
strconv.FormatBool(dirOnly),
|
||||
)
|
||||
|
||||
if service.cacheEnabled && hardRefresh {
|
||||
// Should remove the cache explicitly, so that the following normal list can show the correct result
|
||||
@@ -247,6 +318,7 @@ func (service *Service) listFiles(repositoryURL, referenceName, username, passwo
|
||||
repositoryUrl: repositoryURL,
|
||||
username: username,
|
||||
password: password,
|
||||
authType: authType,
|
||||
tlsSkipVerify: tlsSkipVerify,
|
||||
},
|
||||
referenceName: referenceName,
|
||||
|
||||
@@ -1,12 +1,21 @@
|
||||
package gittypes
|
||||
|
||||
import "errors"
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrIncorrectRepositoryURL = errors.New("git repository could not be found, please ensure that the URL is correct")
|
||||
ErrAuthenticationFailure = errors.New("authentication failed, please ensure that the git credentials are correct")
|
||||
)
|
||||
|
||||
type GitCredentialAuthType int
|
||||
|
||||
const (
|
||||
GitCredentialAuthType_Basic GitCredentialAuthType = iota
|
||||
GitCredentialAuthType_Token
|
||||
)
|
||||
|
||||
// RepoConfig represents a configuration for a repo
|
||||
type RepoConfig struct {
|
||||
// The repo url
|
||||
@@ -24,10 +33,11 @@ type RepoConfig struct {
|
||||
}
|
||||
|
||||
type GitAuthentication struct {
|
||||
Username string
|
||||
Password string
|
||||
Username string
|
||||
Password string
|
||||
AuthorizationType GitCredentialAuthType
|
||||
// Git credentials identifier when the value is not 0
|
||||
// When the value is 0, Username and Password are set without using saved credential
|
||||
// When the value is 0, Username, Password, and Authtype are set without using saved credential
|
||||
// This is introduced since 2.15.0
|
||||
GitCredentialID int `example:"0"`
|
||||
}
|
||||
|
||||
@@ -29,7 +29,14 @@ func UpdateGitObject(gitService portainer.GitService, objId string, gitConfig *g
|
||||
return false, "", errors.WithMessagef(err, "failed to get credentials for %v", objId)
|
||||
}
|
||||
|
||||
newHash, err := gitService.LatestCommitID(gitConfig.URL, gitConfig.ReferenceName, username, password, gitConfig.TLSSkipVerify)
|
||||
newHash, err := gitService.LatestCommitID(
|
||||
gitConfig.URL,
|
||||
gitConfig.ReferenceName,
|
||||
username,
|
||||
password,
|
||||
gittypes.GitCredentialAuthType_Basic,
|
||||
gitConfig.TLSSkipVerify,
|
||||
)
|
||||
if err != nil {
|
||||
return false, "", errors.WithMessagef(err, "failed to fetch latest commit id of %v", objId)
|
||||
}
|
||||
@@ -62,6 +69,7 @@ func UpdateGitObject(gitService portainer.GitService, objId string, gitConfig *g
|
||||
cloneParams.auth = &gitAuth{
|
||||
username: username,
|
||||
password: password,
|
||||
authType: gitConfig.Authentication.AuthorizationType,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,14 +97,31 @@ type cloneRepositoryParameters struct {
|
||||
}
|
||||
|
||||
type gitAuth struct {
|
||||
authType gittypes.GitCredentialAuthType
|
||||
username string
|
||||
password string
|
||||
}
|
||||
|
||||
func cloneGitRepository(gitService portainer.GitService, cloneParams *cloneRepositoryParameters) error {
|
||||
if cloneParams.auth != nil {
|
||||
return gitService.CloneRepository(cloneParams.toDir, cloneParams.url, cloneParams.ref, cloneParams.auth.username, cloneParams.auth.password, cloneParams.tlsSkipVerify)
|
||||
return gitService.CloneRepository(
|
||||
cloneParams.toDir,
|
||||
cloneParams.url,
|
||||
cloneParams.ref,
|
||||
cloneParams.auth.username,
|
||||
cloneParams.auth.password,
|
||||
cloneParams.auth.authType,
|
||||
cloneParams.tlsSkipVerify,
|
||||
)
|
||||
}
|
||||
|
||||
return gitService.CloneRepository(cloneParams.toDir, cloneParams.url, cloneParams.ref, "", "", cloneParams.tlsSkipVerify)
|
||||
return gitService.CloneRepository(
|
||||
cloneParams.toDir,
|
||||
cloneParams.url,
|
||||
cloneParams.ref,
|
||||
"",
|
||||
"",
|
||||
gittypes.GitCredentialAuthType_Basic,
|
||||
cloneParams.tlsSkipVerify,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -3,9 +3,9 @@ package update
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/asaskevich/govalidator"
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
httperrors "github.com/portainer/portainer/api/http/errors"
|
||||
"github.com/portainer/portainer/pkg/validate"
|
||||
)
|
||||
|
||||
func ValidateAutoUpdateSettings(autoUpdate *portainer.AutoUpdateSettings) error {
|
||||
@@ -17,14 +17,18 @@ func ValidateAutoUpdateSettings(autoUpdate *portainer.AutoUpdateSettings) error
|
||||
return httperrors.NewInvalidPayloadError("Webhook or Interval must be provided")
|
||||
}
|
||||
|
||||
if autoUpdate.Webhook != "" && !govalidator.IsUUID(autoUpdate.Webhook) {
|
||||
if autoUpdate.Webhook != "" && !validate.IsUUID(autoUpdate.Webhook) {
|
||||
return httperrors.NewInvalidPayloadError("invalid Webhook format")
|
||||
}
|
||||
|
||||
if autoUpdate.Interval != "" {
|
||||
if _, err := time.ParseDuration(autoUpdate.Interval); err != nil {
|
||||
return httperrors.NewInvalidPayloadError("invalid Interval format")
|
||||
}
|
||||
if autoUpdate.Interval == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if d, err := time.ParseDuration(autoUpdate.Interval); err != nil {
|
||||
return httperrors.NewInvalidPayloadError("invalid Interval format")
|
||||
} else if d < time.Minute {
|
||||
return httperrors.NewInvalidPayloadError("interval must be at least 1 minute")
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -23,6 +23,16 @@ func Test_ValidateAutoUpdate(t *testing.T) {
|
||||
value: &portainer.AutoUpdateSettings{Interval: "1dd2hh3mm"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "short interval value",
|
||||
value: &portainer.AutoUpdateSettings{Interval: "1s"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "valid webhook without interval",
|
||||
value: &portainer.AutoUpdateSettings{Webhook: "8dce8c2f-9ca1-482b-ad20-271e86536ada"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid auto update",
|
||||
value: &portainer.AutoUpdateSettings{
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user