Compare commits

39 Commits

| Author | SHA1 | Date |
|---|---|---|
| | 8c1977e0aa | |
| | c47fd9f9ed | |
| | 767d1d1970 | |
| | ef81e5c0e0 | |
| | 234b7a3d5e | |
| | af49305e64 | |
| | d181d1251c | |
| | 5f7db66e95 | |
| | 17378bdef6 | |
| | 010542ac1e | |
| | 1bb253479a | |
| | f0a13a2ad1 | |
| | f9b28aa0a1 | |
| | d26e1b6983 | |
| | 7b00fdd208 | |
| | 14b998d270 | |
| | 605ff8c1da | |
| | 13f93f4262 | |
| | 16be5ed329 | |
| | c6612898f3 | |
| | 564f34b0ba | |
| | 392fbdb4a7 | |
| | a826c78786 | |
| | a35f0607f1 | |
| | 081d32af0d | |
| | 4cc0b1f567 | |
| | d4da7e1760 | |
| | aced418880 | |
| | 614f42fe5a | |
| | 58736fe93b | |
| | b78330b10d | |
| | eed4a92ca8 | |
| | 0e7468a1e8 | |
| | b807481f1c | |
| | da27de2154 | |
| | 6743e4fbb2 | |
| | b489ffaa63 | |
| | 6e12499d61 | |
| | f7acbe16ba | |
.air.toml (52 lines removed)

@@ -1,52 +0,0 @@
root = "."
testdata_dir = "testdata"
tmp_dir = ".tmp"

[build]
args_bin = []
bin = "./dist/portainer"
cmd = "SKIP_GO_GET=true make build-server"
delay = 1000
exclude_dir = []
exclude_file = []
exclude_regex = ["_test.go"]
exclude_unchanged = false
follow_symlink = false
full_bin = "./dist/portainer --log-level=DEBUG"
include_dir = ["api"]
include_ext = ["go"]
include_file = []
kill_delay = "0s"
log = "build-errors.log"
poll = false
poll_interval = 0
post_cmd = []
pre_cmd = []
rerun = false
rerun_delay = 500
send_interrupt = false
stop_on_error = false

[color]
app = ""
build = "yellow"
main = "magenta"
runner = "green"
watcher = "cyan"

[log]
main_only = false
silent = false
time = false

[misc]
clean_on_exit = false

[proxy]
app_port = 0
enabled = false
proxy_port = 0

[screen]
clear_on_rebuild = false
keep_scroll = true
.codeclimate.yml (new file, 44 lines)

@@ -0,0 +1,44 @@
version: "2"
checks:
  argument-count:
    enabled: false
  complex-logic:
    enabled: false
  file-lines:
    enabled: false
  method-complexity:
    enabled: false
  method-count:
    enabled: false
  method-lines:
    enabled: false
  nested-control-flow:
    enabled: false
  return-statements:
    enabled: false
  similar-code:
    enabled: false
  identical-code:
    enabled: false
plugins:
  gofmt:
    enabled: true
  eslint:
    enabled: true
    channel: "eslint-5"
    config:
      config: .eslintrc.yml
exclude_patterns:
  - assets/
  - build/
  - dist/
  - distribution/
  - node_modules
  - test/
  - webpack/
  - gruntfile.js
  - webpack.config.js
  - api/
  - "!app/kubernetes/**"
  - .github/
  - .tmp/
@@ -10,57 +10,39 @@ globals:
extends:
  - 'eslint:recommended'
  - 'plugin:storybook/recommended'
  - 'plugin:import/typescript'
  - prettier

plugins:
  - import

parserOptions:
  ecmaVersion: latest
  ecmaVersion: 2018
  sourceType: module
  project: './tsconfig.json'
  ecmaFeatures:
    modules: true

rules:
  no-console: error
  no-alert: error
  no-control-regex: 'off'
  no-empty: warn
  no-empty-function: warn
  no-useless-escape: 'off'
  import/named: error
  import/order:
    [
      'error',
      {
        pathGroups:
          [
            { pattern: '@@/**', group: 'internal', position: 'after' },
            { pattern: '@/**', group: 'internal' },
            { pattern: '{Kubernetes,Portainer,Agent,Azure,Docker}/**', group: 'internal' },
          ],
        pathGroups: [{ pattern: '@/**', group: 'internal' }, { pattern: '{Kubernetes,Portainer,Agent,Azure,Docker}/**', group: 'internal' }],
        groups: ['builtin', 'external', 'internal', 'parent', 'sibling', 'index'],
        pathGroupsExcludedImportTypes: ['internal'],
      },
    ]
  no-restricted-imports:
    - error
    - patterns:
        - group:
            - '@/react/test-utils/*'
          message: 'These utils are just for test files'

settings:
  'import/resolver':
    alias:
      map:
        - ['@@', './app/react/components']
        - ['@', './app']
      extensions: ['.js', '.ts', '.tsx']
    typescript: true
    node: true

overrides:
  - files:
@@ -70,7 +52,6 @@ overrides:
    parser: '@typescript-eslint/parser'
    plugins:
      - '@typescript-eslint'
      - 'regex'
    extends:
      - airbnb
      - airbnb-typescript
@@ -85,23 +66,13 @@ overrides:
    settings:
      react:
        version: 'detect'

    rules:
      no-console: error
      import/order:
        [
          'error',
          {
            pathGroups: [{ pattern: '@@/**', group: 'internal', position: 'after' }, { pattern: '@/**', group: 'internal' }],
            groups: ['builtin', 'external', 'internal', 'parent', 'sibling', 'index'],
            'newlines-between': 'always',
          },
        ]
      no-plusplus: off
        ['error', { pathGroups: [{ pattern: '@/**', group: 'internal' }], groups: ['builtin', 'external', 'internal', 'parent', 'sibling', 'index'], 'newlines-between': 'always' }]
      func-style: [error, 'declaration']
      import/prefer-default-export: off
      no-use-before-define: 'off'
      '@typescript-eslint/no-use-before-define': ['error', { functions: false, 'allowNamedExports': true }]
      no-use-before-define: ['error', { functions: false }]
      '@typescript-eslint/no-use-before-define': ['error', { functions: false }]
      no-shadow: 'off'
      '@typescript-eslint/no-shadow': off
      jsx-a11y/no-autofocus: warn
@@ -114,49 +85,21 @@ overrides:
      '@typescript-eslint/explicit-module-boundary-types': off
      '@typescript-eslint/no-unused-vars': 'error'
      '@typescript-eslint/no-explicit-any': 'error'
      'jsx-a11y/label-has-associated-control':
        - error
        - assert: either
          controlComponents:
            - Input
            - Checkbox
      'jsx-a11y/control-has-associated-label': off
      'jsx-a11y/label-has-associated-control': ['error', { 'assert': 'either' }]
      'react/function-component-definition': ['error', { 'namedComponents': 'function-declaration' }]
      'react/jsx-no-bind': off
      'no-await-in-loop': 'off'
      'react/jsx-no-useless-fragment': ['error', { allowExpressions: true }]
      'regex/invalid': ['error', [{ 'regex': '<Icon icon="(.*)"', 'message': 'Please directly import the `lucide-react` icon instead of using the string' }]]
      '@typescript-eslint/no-restricted-imports':
        - error
        - patterns:
            - group:
                - '@/react/test-utils/*'
              message: 'These utils are just for test files'
    overrides: # allow props spreading for hoc files
      - files:
          - app/**/with*.ts{,x}
        rules:
          'react/jsx-props-no-spreading': off
  - files:
      - app/**/*.test.*
    plugins:
      - '@vitest'
    extends:
      - 'plugin:@vitest/legacy-recommended'
      - 'plugin:jest/recommended'
      - 'plugin:jest/style'
    env:
      '@vitest/env': true
      'jest/globals': true
    rules:
      'react/jsx-no-constructed-context-values': off
      '@typescript-eslint/no-restricted-imports': off
      no-restricted-imports: off
      'react/jsx-props-no-spreading': off
      '@vitest/no-conditional-expect': warn
      'max-classes-per-file': off
  - files:
      - app/**/*.stories.*
    rules:
      'no-alert': off
      '@typescript-eslint/no-restricted-imports': off
      no-restricted-imports: off
      'react/jsx-props-no-spreading': off
      'storybook/no-renderer-packages': off
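The `import/order` settings above are easier to read with a concrete module in mind. Below is a minimal sketch of an import block that would satisfy the configured groups and path groups; the imported module paths and component names are hypothetical, chosen only to illustrate the `@/` and `@@/` aliases mapped in `settings`.

```tsx
// Hypothetical component file; the imported modules are illustrative only.
// 1. builtin/external packages come first
import { useState } from 'react';

// 2. "@/..." imports resolve to ./app and are grouped as internal
import { notifySuccess } from '@/portainer/services/notifications';

// 3. "@@/..." imports resolve to ./app/react/components and are placed
//    after the other internal imports (pathGroups position: 'after');
//    'newlines-between: always' in the TS override keeps groups separated.
import { Button } from '@@/buttons';

// 'react/function-component-definition' requires a function declaration.
export function ExamplePanel() {
  const [count, setCount] = useState(0);
  return <Button onClick={() => setCount(count + 1)}>Clicked {count}</Button>;
}
```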
@@ -2,7 +2,4 @@
cf5056d9c03b62d91a25c3b9127caac838695f98

# prettier v2
42e7db0ae7897d3cb72b0ea1ecf57ee2dd694169

# tailwind prettier
58d66d3142950bb90a7d85511c034ac9fabba9ba
42e7db0ae7897d3cb72b0ea1ecf57ee2dd694169
.github/DISCUSSION_TEMPLATE/help.yaml (11 lines removed)

@@ -1,11 +0,0 @@
body:
  - type: markdown
    attributes:
      value: |
        Before asking a question, make sure it hasn't been already asked and answered. You can search our [discussions](https://github.com/orgs/portainer/discussions) and [bug reports](https://github.com/portainer/portainer/issues) in GitHub. Also, be sure to check our [knowledge base](https://portal.portainer.io/knowledge) and [documentation](https://docs.portainer.io/) first.

  - type: textarea
    attributes:
      label: Ask a Question!
    validations:
      required: true
.github/DISCUSSION_TEMPLATE/ideas.yaml (38 lines removed)

@@ -1,38 +0,0 @@
body:
  - type: markdown
    attributes:
      value: |
        # Welcome!

        Thanks for suggesting an idea for Portainer!

        Before opening a new idea or feature request, make sure that we do not have any duplicates already open. You can ensure this by [searching this discussion category](https://github.com/orgs/portainer/discussions/categories/ideas). If there is a duplicate, please add a comment to the existing idea instead.

        Also, be sure to check our [knowledge base](https://portal.portainer.io/knowledge) and [documentation](https://docs.portainer.io) as they may point you toward a solution.

        **DO NOT FILE DUPLICATE REQUESTS.**

  - type: textarea
    attributes:
      label: Is your feature request related to a problem? Please describe
      description: Short list of what the feature request aims to address.
    validations:
      required: true
  - type: textarea
    attributes:
      label: Describe the solution you'd like
      description: A clear and concise description of what you want to happen.
    validations:
      required: true
  - type: textarea
    attributes:
      label: Describe alternatives you've considered
      description: A clear and concise description of any alternative solutions or features you've considered.
    validations:
      required: true
  - type: textarea
    attributes:
      label: Additional context
      description: Add any other context or screenshots about the feature request here.
    validations:
      required: false
.github/ISSUE_TEMPLATE/Bug_report.md (new file, 54 lines)

@@ -0,0 +1,54 @@
---
name: Bug report
about: Create a bug report
title: ''
labels: bug/need-confirmation, kind/bug
assignees: ''
---

<!--

Thanks for reporting a bug for Portainer !

You can find more information about Portainer support framework policy here: https://www.portainer.io/2019/04/portainer-support-policy/

Do you need help or have a question? Come chat with us on Slack https://portainer.io/slack/

Before opening a new issue, make sure that we do not have any duplicates
already open. You can ensure this by searching the issue list for this
repository. If there is a duplicate, please close your issue and add a comment
to the existing issue instead.

Also, be sure to check our FAQ and documentation first: https://documentation.portainer.io/
-->

**Bug description**
A clear and concise description of what the bug is.

**Expected behavior**
A clear and concise description of what you expected to happen.

**Portainer Logs**
Provide the logs of your Portainer container or Service.
You can see how [here](https://documentation.portainer.io/r/portainer-logs)

**Steps to reproduce the issue:**

1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error

**Technical details:**

- Portainer version:
- Docker version (managed by Portainer):
- Kubernetes version (managed by Portainer):
- Platform (windows/linux):
- Command used to start Portainer (`docker run -p 9443:9443 portainer/portainer`):
- Browser:
- Use Case (delete as appropriate): Using Portainer at Home, Using Portainer in a Commercial setup.
- Have you reviewed our technical documentation and knowledge base? Yes/No

**Additional context**
Add any other context about the problem here.
.github/ISSUE_TEMPLATE/bug_report.yml (203 lines removed)

@@ -1,203 +0,0 @@
name: Bug Report
description: Create a report to help us improve.
labels: kind/bug,bug/need-confirmation
body:
  - type: markdown
    attributes:
      value: |
        # Welcome!

        The issue tracker is for reporting bugs. If you have an [idea for a new feature](https://github.com/orgs/portainer/discussions/categories/ideas) or a [general question about Portainer](https://github.com/orgs/portainer/discussions/categories/help) please post in our [GitHub Discussions](https://github.com/orgs/portainer/discussions).

        You can also ask for help in our [community Slack channel](https://join.slack.com/t/portainer/shared_invite/zt-txh3ljab-52QHTyjCqbe5RibC2lcjKA).

        Please note that we only provide support for current versions of Portainer. You can find a list of supported versions in our [lifecycle policy](https://docs.portainer.io/start/lifecycle).

        **DO NOT FILE ISSUES FOR GENERAL SUPPORT QUESTIONS**.

  - type: checkboxes
    id: terms
    attributes:
      label: Before you start please confirm the following.
      options:
        - label: Yes, I've searched similar issues on [GitHub](https://github.com/portainer/portainer/issues).
          required: true
        - label: Yes, I've checked whether this issue is covered in the Portainer [documentation](https://docs.portainer.io).
          required: true

  - type: markdown
    attributes:
      value: |
        # About your issue

        Tell us a bit about the issue you're having.

        How to write a good bug report:

        - Respect the issue template as much as possible.
        - Summarize the issue so that we understand what is going wrong.
        - Describe what you would have expected to have happened, and what actually happened instead.
        - Provide easy to follow steps to reproduce the issue.
        - Remain clear and concise.
        - Format your messages to help the reader focus on what matters and understand the structure of your message, use [Markdown syntax](https://help.github.com/articles/github-flavored-markdown).

  - type: textarea
    attributes:
      label: Problem Description
      description: A clear and concise description of what the bug is.
    validations:
      required: true

  - type: textarea
    attributes:
      label: Expected Behavior
      description: A clear and concise description of what you expected to happen.
    validations:
      required: true

  - type: textarea
    attributes:
      label: Actual Behavior
      description: A clear and concise description of what actually happens.
    validations:
      required: true

  - type: textarea
    attributes:
      label: Steps to Reproduce
      description: Please be as detailed as possible when providing steps to reproduce.
      placeholder: |
        1. Go to '...'
        2. Click on '....'
        3. Scroll down to '....'
        4. See error
    validations:
      required: true

  - type: textarea
    attributes:
      label: Portainer logs or screenshots
      description: Provide Portainer container logs or any screenshots related to the issue.
    validations:
      required: false

  - type: markdown
    attributes:
      value: |
        # About your environment

        Tell us a bit about your Portainer environment.

  - type: dropdown
    attributes:
      label: Portainer version
      description: We only provide support for current versions of Portainer as per the lifecycle policy linked above. If you are on an older version of Portainer we recommend [updating first](https://docs.portainer.io/start/upgrade) in case your bug has already been fixed.
      multiple: false
      options:
        - '2.39.0'
        - '2.38.1'
        - '2.38.0'
        - '2.37.0'
        - '2.36.0'
        - '2.35.0'
        - '2.34.0'
        - '2.33.7'
        - '2.33.6'
        - '2.33.5'
        - '2.33.4'
        - '2.33.3'
        - '2.33.2'
        - '2.33.1'
        - '2.33.0'
        - '2.32.0'
        - '2.31.3'
        - '2.31.2'
        - '2.31.1'
        - '2.31.0'
        - '2.30.1'
        - '2.30.0'
        - '2.29.2'
        - '2.29.1'
        - '2.29.0'
        - '2.28.1'
        - '2.28.0'
        - '2.27.9'
        - '2.27.8'
        - '2.27.7'
        - '2.27.6'
        - '2.27.5'
        - '2.27.4'
        - '2.27.3'
        - '2.27.2'
        - '2.27.1'
        - '2.27.0'
        - '2.26.1'
        - '2.26.0'
        - '2.25.1'
        - '2.25.0'
        - '2.24.1'
        - '2.24.0'
        - '2.23.0'
        - '2.22.0'
        - '2.21.5'
        - '2.21.4'
        - '2.21.3'
        - '2.21.2'
    validations:
      required: true

  - type: dropdown
    attributes:
      label: Portainer Edition
      multiple: false
      options:
        - 'Business Edition (BE/EE) with 5NF / 3NF license'
        - 'Business Edition (BE/EE) with Home & Student license'
        - 'Business Edition (BE/EE) with Starter license'
        - 'Business Edition (BE/EE) with Professional or Enterprise license'
        - 'Community Edition (CE)'
    validations:
      required: true

  - type: input
    attributes:
      label: Platform and Version
      description: |
        Enter your container management platform (Docker | Swarm | Kubernetes) along with the version.
        Example: Docker 24.0.3 | Docker Swarm 24.0.3 | Kubernetes 1.26
        You can find our supported platforms [in our documentation](https://docs.portainer.io/start/requirements-and-prerequisites).
    validations:
      required: true

  - type: input
    attributes:
      label: OS and Architecture
      description: |
        Enter your Operating System, Version and Architecture. Example: Ubuntu 22.04, AMD64 | Raspbian OS, ARM64
    validations:
      required: true

  - type: input
    attributes:
      label: Browser
      description: |
        Enter your browser and version. Example: Google Chrome 114.0
    validations:
      required: false

  - type: textarea
    attributes:
      label: What command did you use to deploy Portainer?
      description: |
        Example: `docker run -d -p 8000:8000 -p 9443:9443 --name portainer --restart=always -v /var/run/docker.sock:/var/run/docker.sock -v portainer_data:/data portainer/portainer-ce:latest`
        If you deployed Portainer using a compose file or manifest you can provide this here as well.
      render: bash
    validations:
      required: false

  - type: textarea
    attributes:
      label: Additional Information
      description: Any additional information about your environment, the bug, or anything else you think might be helpful.
    validations:
      required: false
.github/ISSUE_TEMPLATE/config.yml (12 lines changed)

@@ -1,11 +1,5 @@
blank_issues_enabled: false
contact_links:
  - name: Question
    url: https://github.com/orgs/portainer/discussions/new?category=help
    about: Ask us a question about Portainer usage or deployment.
  - name: Idea or Feature Request
    url: https://github.com/orgs/portainer/discussions/new?category=ideas
    about: Suggest an idea or feature/enhancement that should be added in Portainer.
  - name: Portainer Business Edition - Get 3 Nodes Free
    url: https://www.portainer.io/take-3
    about: Portainer Business Edition has more features, more support and you can now get 3 nodes free for as long as you want.
  - name: Portainer Business Edition - Get 5 nodes free
    url: https://portainer.io/pricing/take5
    about: Portainer Business Edition has more features, more support and you can now get 5 nodes free for as long as you want.
.github/workflows/label-conflcts.yaml (new file, 15 lines)

@@ -0,0 +1,15 @@
on:
  push:
    branches:
      - develop
      - 'release/**'
jobs:
  triage:
    runs-on: ubuntu-latest
    steps:
      - uses: mschilde/auto-label-merge-conflicts@master
        with:
          CONFLICT_LABEL_NAME: 'has conflicts'
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          MAX_RETRIES: 5
          WAIT_MS: 5000
.github/workflows/lint.yml (new file, 38 lines)

@@ -0,0 +1,38 @@
name: Lint

on:
  push:
    branches:
      - master
      - develop
      - release/*
  pull_request:
    branches:
      - master
      - develop
      - release/*

jobs:
  run-linters:
    name: Run linters
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v2
      - uses: actions/setup-node@v2
        with:
          node-version: '14'
          cache: 'yarn'
      - run: yarn --frozen-lockfile

      - name: Run linters
        uses: wearerequired/lint-action@v1
        with:
          eslint: true
          eslint_extensions: ts,tsx,js,jsx
          prettier: true
          prettier_dir: app/
          gofmt: true
          gofmt_dir: api/
      - name: Typecheck
        uses: icrawl/action-tsc@v1
.github/workflows/nightly-security-scan.yml (new file, 230 lines)

@@ -0,0 +1,230 @@
name: Nightly Code Security Scan

on:
  schedule:
    - cron: '0 8 * * *'
  workflow_dispatch:

jobs:
  client-dependencies:
    name: Client dependency check
    runs-on: ubuntu-latest
    if: >- # only run for develop branch
      github.ref == 'refs/heads/develop'
    outputs:
      js: ${{ steps.set-matrix.outputs.js_result }}
    steps:
      - uses: actions/checkout@master

      - name: Run Snyk to check for vulnerabilities
        uses: snyk/actions/node@master
        continue-on-error: true # To make sure that artifact upload gets called
        env:
          SNYK_TOKEN: ${{ secrets.SNYK_TOKEN }}
        with:
          json: true

      - name: Upload js security scan result as artifact
        uses: actions/upload-artifact@v3
        with:
          name: js-security-scan-develop-result
          path: snyk.json

      - name: Export scan result to html file
        run: |
          $(docker run --rm -v ${{ github.workspace }}:/data oscarzhou/scan-report:0.1.8 summary -report-type=snyk -path="/data/snyk.json" -output-type=table -export -export-filename="/data/js-result")

      - name: Upload js result html file
        uses: actions/upload-artifact@v3
        with:
          name: html-js-result-${{github.run_id}}
          path: js-result.html

      - name: Analyse the js result
        id: set-matrix
        run: |
          result=$(docker run --rm -v ${{ github.workspace }}:/data oscarzhou/scan-report:0.1.8 summary -report-type=snyk -path="/data/snyk.json" -output-type=matrix)
          echo "::set-output name=js_result::${result}"

  server-dependencies:
    name: Server dependency check
    runs-on: ubuntu-latest
    if: >- # only run for develop branch
      github.ref == 'refs/heads/develop'
    outputs:
      go: ${{ steps.set-matrix.outputs.go_result }}
    steps:
      - uses: actions/checkout@master

      - name: Download go modules
        run: cd ./api && go get -t -v -d ./...

      - name: Run Snyk to check for vulnerabilities
        uses: snyk/actions/golang@master
        continue-on-error: true # To make sure that artifact upload gets called
        env:
          SNYK_TOKEN: ${{ secrets.SNYK_TOKEN }}
        with:
          args: --file=./api/go.mod
          json: true

      - name: Upload go security scan result as artifact
        uses: actions/upload-artifact@v3
        with:
          name: go-security-scan-develop-result
          path: snyk.json

      - name: Export scan result to html file
        run: |
          $(docker run --rm -v ${{ github.workspace }}:/data oscarzhou/scan-report:0.1.8 summary -report-type=snyk -path="/data/snyk.json" -output-type=table -export -export-filename="/data/go-result")

      - name: Upload go result html file
        uses: actions/upload-artifact@v3
        with:
          name: html-go-result-${{github.run_id}}
          path: go-result.html

      - name: Analyse the go result
        id: set-matrix
        run: |
          result=$(docker run --rm -v ${{ github.workspace }}:/data oscarzhou/scan-report:0.1.8 summary -report-type=snyk -path="/data/snyk.json" -output-type=matrix)
          echo "::set-output name=go_result::${result}"

  image-vulnerability:
    name: Build docker image and Image vulnerability check
    runs-on: ubuntu-latest
    if: >-
      github.ref == 'refs/heads/develop'
    outputs:
      image: ${{ steps.set-matrix.outputs.image_result }}
    steps:
      - name: Checkout code
        uses: actions/checkout@master

      - name: Use golang 1.18
        uses: actions/setup-go@v3
        with:
          go-version: '1.18'

      - name: Use Node.js 12.x
        uses: actions/setup-node@v1
        with:
          node-version: 12.x

      - name: Install packages and build
        run: yarn install && yarn build

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v1

      - name: Build and push
        uses: docker/build-push-action@v2
        with:
          context: .
          file: build/linux/Dockerfile
          tags: trivy-portainer:${{ github.sha }}
          outputs: type=docker,dest=/tmp/trivy-portainer-image.tar

      - name: Load docker image
        run: |
          docker load --input /tmp/trivy-portainer-image.tar

      - name: Run Trivy vulnerability scanner
        uses: docker://docker.io/aquasec/trivy:latest
        continue-on-error: true
        with:
          args: image --ignore-unfixed=true --vuln-type="os,library" --exit-code=1 --format="json" --output="image-trivy.json" --no-progress trivy-portainer:${{ github.sha }}

      - name: Upload image security scan result as artifact
        uses: actions/upload-artifact@v3
        with:
          name: image-security-scan-develop-result
          path: image-trivy.json

      - name: Export scan result to html file
        run: |
          $(docker run --rm -v ${{ github.workspace }}:/data oscarzhou/scan-report:0.1.8 summary -report-type=trivy -path="/data/image-trivy.json" -output-type=table -export -export-filename="/data/image-result")

      - name: Upload go result html file
        uses: actions/upload-artifact@v3
        with:
          name: html-image-result-${{github.run_id}}
          path: image-result.html

      - name: Analyse the trivy result
        id: set-matrix
        run: |
          result=$(docker run --rm -v ${{ github.workspace }}:/data oscarzhou/scan-report:0.1.8 summary -report-type=trivy -path="/data/image-trivy.json" -output-type=matrix)
          echo "::set-output name=image_result::${result}"

  result-analysis:
    name: Analyse scan result
    needs: [client-dependencies, server-dependencies, image-vulnerability]
    runs-on: ubuntu-latest
    if: >-
      github.ref == 'refs/heads/develop'
    strategy:
      matrix:
        js: ${{fromJson(needs.client-dependencies.outputs.js)}}
        go: ${{fromJson(needs.server-dependencies.outputs.go)}}
        image: ${{fromJson(needs.image-vulnerability.outputs.image)}}
    steps:
      - name: Display the results of js, go and image
        run: |
          echo ${{ matrix.js.status }}
          echo ${{ matrix.go.status }}
          echo ${{ matrix.image.status }}
          echo ${{ matrix.js.summary }}
          echo ${{ matrix.go.summary }}
          echo ${{ matrix.image.summary }}

      - name: Send Slack message
        if: >-
          matrix.js.status == 'failure' ||
          matrix.go.status == 'failure' ||
          matrix.image.status == 'failure'
        uses: slackapi/slack-github-action@v1.18.0
        with:
          payload: |
            {
              "blocks": [
                {
                  "type": "section",
                  "text": {
                    "type": "mrkdwn",
                    "text": "Code Scanning Result (*${{ github.repository }}*)\n*<${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}|GitHub Actions Workflow URL>*"
                  }
                }
              ],
              "attachments": [
                {
                  "color": "#FF0000",
                  "blocks": [
                    {
                      "type": "section",
                      "text": {
                        "type": "mrkdwn",
                        "text": "*JS dependency check*: *${{ matrix.js.status }}*\n${{ matrix.js.summary }}"
                      }
                    },
                    {
                      "type": "section",
                      "text": {
                        "type": "mrkdwn",
                        "text": "*Go dependency check*: *${{ matrix.go.status }}*\n${{ matrix.go.summary }}"
                      }
                    },
                    {
                      "type": "section",
                      "text": {
                        "type": "mrkdwn",
                        "text": "*Image vulnerability check*: *${{ matrix.image.status }}*\n${{ matrix.image.summary }}\n"
                      }
                    }
                  ]
                }
              ]
            }
        env:
          SLACK_WEBHOOK_URL: ${{ secrets.SECURITY_SLACK_WEBHOOK_URL }}
          SLACK_WEBHOOK_TYPE: INCOMING_WEBHOOK
.github/workflows/pr-security.yml (new file, 233 lines)

@@ -0,0 +1,233 @@
name: PR Code Security Scan

on:
  pull_request_review:
    types:
      - submitted
      - edited
    paths:
      - 'package.json'
      - 'api/go.mod'
      - 'gruntfile.js'
      - 'build/linux/Dockerfile'
      - 'build/linux/alpine.Dockerfile'
      - 'build/windows/Dockerfile'

jobs:
  client-dependencies:
    name: Client dependency check
    runs-on: ubuntu-latest
    if: >-
      github.event.pull_request &&
      github.event.review.body == '/scan'
    outputs:
      jsdiff: ${{ steps.set-diff-matrix.outputs.js_diff_result }}
    steps:
      - uses: actions/checkout@master

      - name: Run Snyk to check for vulnerabilities
        uses: snyk/actions/node@master
        continue-on-error: true # To make sure that artifact upload gets called
        env:
          SNYK_TOKEN: ${{ secrets.SNYK_TOKEN }}
        with:
          json: true

      - name: Upload js security scan result as artifact
        uses: actions/upload-artifact@v3
        with:
          name: js-security-scan-feat-result
          path: snyk.json

      - name: Download artifacts from develop branch
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: |
          mv ./snyk.json ./js-snyk-feature.json
          (gh run download -n js-security-scan-develop-result -R ${{ github.repository }} 2>&1 >/dev/null) || :
          if [[ -e ./snyk.json ]]; then
            mv ./snyk.json ./js-snyk-develop.json
          else
            echo "null" > ./js-snyk-develop.json
          fi

      - name: Export scan result to html file
        run: |
          $(docker run --rm -v ${{ github.workspace }}:/data oscarzhou/scan-report:0.1.8 diff -report-type=snyk -path="/data/js-snyk-feature.json" -compare-to="/data/js-snyk-develop.json" -output-type=table -export -export-filename="/data/js-result")

      - name: Upload js result html file
        uses: actions/upload-artifact@v3
        with:
          name: html-js-result-compare-to-develop-${{github.run_id}}
          path: js-result.html

      - name: Analyse the js diff result
        id: set-diff-matrix
        run: |
          result=$(docker run --rm -v ${{ github.workspace }}:/data oscarzhou/scan-report:0.1.8 diff -report-type=snyk -path="/data/js-snyk-feature.json" -compare-to="./data/js-snyk-develop.json" -output-type=matrix)
          echo "::set-output name=js_diff_result::${result}"

  server-dependencies:
    name: Server dependency check
    runs-on: ubuntu-latest
    if: >-
      github.event.pull_request &&
      github.event.review.body == '/scan'
    outputs:
      godiff: ${{ steps.set-diff-matrix.outputs.go_diff_result }}
    steps:
      - uses: actions/checkout@master

      - name: Download go modules
        run: cd ./api && go get -t -v -d ./...

      - name: Run Snyk to check for vulnerabilities
        uses: snyk/actions/golang@master
        continue-on-error: true # To make sure that artifact upload gets called
        env:
          SNYK_TOKEN: ${{ secrets.SNYK_TOKEN }}
        with:
          args: --file=./api/go.mod
          json: true

      - name: Upload go security scan result as artifact
        uses: actions/upload-artifact@v3
        with:
          name: go-security-scan-feature-result
          path: snyk.json

      - name: Download artifacts from develop branch
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: |
          mv ./snyk.json ./go-snyk-feature.json
          (gh run download -n go-security-scan-develop-result -R ${{ github.repository }} 2>&1 >/dev/null) || :
          if [[ -e ./snyk.json ]]; then
            mv ./snyk.json ./go-snyk-develop.json
          else
            echo "null" > ./go-snyk-develop.json
          fi

      - name: Export scan result to html file
        run: |
          $(docker run --rm -v ${{ github.workspace }}:/data oscarzhou/scan-report:0.1.8 diff -report-type=snyk -path="/data/go-snyk-feature.json" -compare-to="/data/go-snyk-develop.json" -output-type=table -export -export-filename="/data/go-result")

      - name: Upload go result html file
        uses: actions/upload-artifact@v3
        with:
          name: html-go-result-compare-to-develop-${{github.run_id}}
          path: go-result.html

      - name: Analyse the go diff result
        id: set-diff-matrix
        run: |
          result=$(docker run --rm -v ${{ github.workspace }}:/data oscarzhou/scan-report:0.1.8 diff -report-type=snyk -path="/data/go-snyk-feature.json" -compare-to="/data/go-snyk-develop.json" -output-type=matrix)
          echo "::set-output name=go_diff_result::${result}"

  image-vulnerability:
    name: Build docker image and Image vulnerability check
    runs-on: ubuntu-latest
    if: >-
      github.event.pull_request &&
      github.event.review.body == '/scan'
    outputs:
      imagediff: ${{ steps.set-diff-matrix.outputs.image_diff_result }}
    steps:
      - name: Checkout code
        uses: actions/checkout@master

      - name: Use golang 1.18
        uses: actions/setup-go@v3
        with:
          go-version: '1.18'

      - name: Use Node.js 12.x
        uses: actions/setup-node@v1
        with:
          node-version: 12.x

      - name: Install packages and build
        run: yarn install && yarn build

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v1

      - name: Build and push
        uses: docker/build-push-action@v2
        with:
          context: .
          file: build/linux/Dockerfile
          tags: trivy-portainer:${{ github.sha }}
          outputs: type=docker,dest=/tmp/trivy-portainer-image.tar

      - name: Load docker image
        run: |
          docker load --input /tmp/trivy-portainer-image.tar

      - name: Run Trivy vulnerability scanner
        uses: docker://docker.io/aquasec/trivy:latest
        continue-on-error: true
        with:
          args: image --ignore-unfixed=true --vuln-type="os,library" --exit-code=1 --format="json" --output="image-trivy.json" --no-progress trivy-portainer:${{ github.sha }}

      - name: Upload image security scan result as artifact
        uses: actions/upload-artifact@v3
        with:
          name: image-security-scan-feature-result
          path: image-trivy.json

      - name: Download artifacts from develop branch
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: |
          mv ./image-trivy.json ./image-trivy-feature.json
          (gh run download -n image-security-scan-develop-result -R ${{ github.repository }} 2>&1 >/dev/null) || :
          if [[ -e ./image-trivy.json ]]; then
            mv ./image-trivy.json ./image-trivy-develop.json
          else
            echo "null" > ./image-trivy-develop.json
          fi

      - name: Export scan result to html file
        run: |
          $(docker run --rm -v ${{ github.workspace }}:/data oscarzhou/scan-report:0.1.8 diff -report-type=trivy -path="/data/image-trivy-feature.json" -compare-to="/data/image-trivy-develop.json" -output-type=table -export -export-filename="/data/image-result")

      - name: Upload image result html file
        uses: actions/upload-artifact@v3
        with:
          name: html-image-result-compare-to-develop-${{github.run_id}}
          path: image-result.html

      - name: Analyse the image diff result
        id: set-diff-matrix
        run: |
          result=$(docker run --rm -v ${{ github.workspace }}:/data oscarzhou/scan-report:0.1.8 diff -report-type=trivy -path="/data/image-trivy-feature.json" -compare-to="./data/image-trivy-develop.json" -output-type=matrix)
          echo "::set-output name=image_diff_result::${result}"

  result-analysis:
    name: Analyse scan result compared to develop
    needs: [client-dependencies, server-dependencies, image-vulnerability]
    runs-on: ubuntu-latest
    if: >-
      github.event.pull_request &&
      github.event.review.body == '/scan'
    strategy:
      matrix:
        jsdiff: ${{fromJson(needs.client-dependencies.outputs.jsdiff)}}
        godiff: ${{fromJson(needs.server-dependencies.outputs.godiff)}}
        imagediff: ${{fromJson(needs.image-vulnerability.outputs.imagediff)}}
    steps:
      - name: Check job status of diff result
        if: >-
          matrix.jsdiff.status == 'failure' ||
          matrix.godiff.status == 'failure' ||
          matrix.imagediff.status == 'failure'
        run: |
          echo ${{ matrix.jsdiff.status }}
          echo ${{ matrix.godiff.status }}
          echo ${{ matrix.imagediff.status }}
          echo ${{ matrix.jsdiff.summary }}
          echo ${{ matrix.godiff.summary }}
          echo ${{ matrix.imagediff.summary }}
          exit 1
.github/workflows/rebase.yml (new file, 19 lines)

@@ -0,0 +1,19 @@
name: Automatic Rebase
on:
  issue_comment:
    types: [created]
jobs:
  rebase:
    name: Rebase
    if: github.event.issue.pull_request != '' && contains(github.event.comment.body, '/rebase')
    runs-on: ubuntu-latest
    steps:
      - name: Checkout the latest code
        uses: actions/checkout@v2
        with:
          token: ${{ secrets.GITHUB_TOKEN }}
          fetch-depth: 0 # otherwise, you will fail to push refs to dest repo
      - name: Automatic Rebase
        uses: cirrus-actions/rebase@1.4
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
.github/workflows/stale.yml (new file, 27 lines)

@@ -0,0 +1,27 @@
name: Close Stale Issues
on:
  schedule:
    - cron: '0 12 * * *'
jobs:
  stale:
    runs-on: ubuntu-latest
    permissions:
      issues: write

    steps:
      - uses: actions/stale@v4.0.0
        with:
          repo-token: ${{ secrets.GITHUB_TOKEN }}

          # Issue Config
          days-before-issue-stale: 60
          days-before-issue-close: 7
          stale-issue-label: 'status/stale'
          exempt-all-issue-milestones: true # Do not stale issues in a milestone
          exempt-issue-labels: kind/enhancement, kind/style, kind/workaround, kind/refactor, bug/need-confirmation, bug/confirmed, status/discuss
          stale-issue-message: 'This issue has been marked as stale as it has not had recent activity, it will be closed if no further activity occurs in the next 7 days. If you believe that it has been incorrectly labelled as stale, leave a comment and the label will be removed.'
          close-issue-message: 'Since no further activity has appeared on this issue it will be closed. If you believe that it has been incorrectly closed, leave a comment mentioning `portainer/support` and one of our staff will then review the issue. Note - If it is an old bug report, make sure that it is reproduceable in the latest version of Portainer as it may have already been fixed.'

          # Pull Request Config
          days-before-pr-stale: -1 # Do not stale pull request
          days-before-pr-close: -1 # Do not close pull request
.github/workflows/test.yaml (new file, 29 lines)

@@ -0,0 +1,29 @@
name: Test
on: push
jobs:
  test-client:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v2
      - uses: actions/setup-node@v2
        with:
          node-version: '14'
          cache: 'yarn'
      - run: yarn --frozen-lockfile

      - name: Run tests
        run: yarn test:client
  # test-server:
  #   runs-on: ubuntu-latest
  #   env:
  #     GOPRIVATE: "github.com/portainer"
  #   steps:
  #     - uses: actions/checkout@v3
  #     - uses: actions/setup-go@v3
  #       with:
  #         go-version: '1.18'
  #     - name: Run tests
  #       run: |
  #         cd api
  #         go test ./...
.github/workflows/validate-openapi-spec.yml (new file, 53 lines)

@@ -0,0 +1,53 @@
name: Validate

on:
  pull_request:
    branches:
      - master
      - develop
      - 'release/*'

jobs:
  openapi-spec:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout Code
        uses: actions/checkout@v2
      - name: Setup Node v14
        uses: actions/setup-node@v2
        with:
          node-version: 14

      # https://github.com/actions/cache/blob/main/examples.md#node---yarn
      - name: Get yarn cache directory path
        id: yarn-cache-dir-path
        run: echo "::set-output name=dir::$(yarn cache dir)"

      - uses: actions/cache@v2
        id: yarn-cache # use this to check for `cache-hit` (`steps.yarn-cache.outputs.cache-hit != 'true'`)
        with:
          path: ${{ steps.yarn-cache-dir-path.outputs.dir }}
          key: ${{ runner.os }}-yarn-${{ hashFiles('**/yarn.lock') }}
          restore-keys: |
            ${{ runner.os }}-yarn-

      - name: Setup Go v1.17.3
        uses: actions/setup-go@v2
        with:
          go-version: '^1.17.3'

      - name: Prebuild docs
        run: yarn prebuild:docs

      - name: Build OpenAPI 2.0 Spec
        run: yarn build:docs

      # Install dependencies globally to bypass installing all frontend deps
      - name: Install swagger2openapi and swagger-cli
        run: yarn global add swagger2openapi @apidevtools/swagger-cli

      # OpenAPI2.0 does not support multiple body params (which we utilise in some of our handlers).
      # OAS3.0 however does support multiple body params - hence its best to convert the generated OAS 2.0
      # to OAS 3.0 and validate the output of generated OAS 3.0 instead.
      - name: Convert OpenAPI 2.0 to OpenAPI 3.0 and validate spec
        run: yarn validate:docs
.gitignore (7 lines changed)

@@ -7,16 +7,11 @@ storybook-static
.tmp
**/.vscode/settings.json
**/.vscode/tasks.json
.vscode
*.DS_Store

.eslintcache
__debug_bin*
__debug_bin

api/docs
.idea
.env
go.work.sum

.vitest
@@ -1,16 +0,0 @@
version: "2"
linters:
  default: none
  enable:
    - forbidigo
  settings:
    forbidigo:
      forbid:
        - pattern: ^dataservices.DataStore.(EdgeGroup|EdgeJob|EdgeStack|EndpointRelation|Endpoint|GitCredential|Registry|ResourceControl|Role|Settings|Snapshot|SSLSettings|Stack|Tag|User)$
          msg: Use a transaction instead
      analyze-types: true
  exclusions:
    rules:
      - path: _test\.go
        linters:
          - forbidigo
.golangci.yaml (108 lines removed)

@@ -1,108 +0,0 @@
version: "2"

run:
  allow-parallel-runners: true
linters:
  default: none
  enable:
    - bodyclose
    - copyloopvar
    - depguard
    - errcheck
    - errorlint
    - forbidigo
    - govet
    - ineffassign
    - intrange
    - perfsprint
    - staticcheck
    - unused
    - mirror
    - durationcheck
    - errorlint
    - govet
    - usetesting
    - zerologlint
    - testifylint
    - modernize
    - unconvert
    - unused
    - zerologlint
    - exptostd
  settings:
    staticcheck:
      checks: ["all", "-ST1003", "-ST1005", "-ST1016", "-SA1019", "-QF1003"]
    depguard:
      rules:
        main:
          files:
            - '!**/*_test.go'
            - '!**/base.go'
            - '!**/base_tx.go'
          deny:
            - pkg: encoding/json
              desc: use github.com/segmentio/encoding/json
            - pkg: golang.org/x/exp
              desc: exp is not allowed
            - pkg: github.com/portainer/libcrypto
              desc: use github.com/portainer/portainer/pkg/libcrypto
            - pkg: github.com/portainer/libhttp
              desc: use github.com/portainer/portainer/pkg/libhttp
            - pkg: golang.org/x/crypto
              desc: golang.org/x/crypto is not allowed because of FIPS mode
            - pkg: github.com/ProtonMail/go-crypto/openpgp
              desc: github.com/ProtonMail/go-crypto/openpgp is not allowed because of FIPS mode
            - pkg: github.com/cosi-project/runtime
              desc: github.com/cosi-project/runtime is not allowed because of FIPS mode
            - pkg: gopkg.in/yaml.v2
              desc: use go.yaml.in/yaml/v3 instead
            - pkg: gopkg.in/yaml.v3
              desc: use go.yaml.in/yaml/v3 instead
            - pkg: github.com/golang-jwt/jwt/v4
              desc: use github.com/golang-jwt/jwt/v5 instead
            - pkg: github.com/mitchellh/mapstructure
              desc: use github.com/go-viper/mapstructure/v2 instead
            - pkg: gopkg.in/alecthomas/kingpin.v2
              desc: use github.com/alecthomas/kingpin/v2 instead
            - pkg: github.com/jcmturner/gokrb5$
              desc: use github.com/jcmturner/gokrb5/v8 instead
            - pkg: github.com/gofrs/uuid
              desc: use github.com/google/uuid
            - pkg: github.com/Masterminds/semver$
              desc: use github.com/Masterminds/semver/v3
            - pkg: github.com/blang/semver
              desc: use github.com/Masterminds/semver/v3
            - pkg: github.com/coreos/go-semver
              desc: use github.com/Masterminds/semver/v3
            - pkg: github.com/hashicorp/go-version
              desc: use github.com/Masterminds/semver/v3
    forbidigo:
      forbid:
        - pattern: ^tls\.Config$
          msg: Use crypto.CreateTLSConfiguration() instead
        - pattern: ^tls\.Config\.(InsecureSkipVerify|MinVersion|MaxVersion|CipherSuites|CurvePreferences)$
          msg: Do not set this field directly, use crypto.CreateTLSConfiguration() instead
        - pattern: ^object\.(Commit|Tag)\.Verify$
          msg: "Not allowed because of FIPS mode"
        - pattern: ^(types\.SystemContext\.)?(DockerDaemonInsecureSkipTLSVerify|DockerInsecureSkipTLSVerify|OCIInsecureSkipTLSVerify)$
          msg: "Not allowed because of FIPS mode"
      analyze-types: true
  exclusions:
    generated: lax
    presets:
      - comments
      - common-false-positives
      - legacy
    paths:
      - third_party$
      - builtin$
      - examples$
formatters:
  enable:
    - gofmt
  exclusions:
    generated: lax
    paths:
      - third_party$
      - builtin$
      - examples$
@@ -1,4 +0,0 @@
#!/usr/bin/env sh
. "$(dirname -- "$0")/_/husky.sh"

cd $(dirname -- "$0") && pnpm lint-staged
@@ -1,3 +1,2 @@
dist
api/datastore/test_data
coverage
api/datastore/test_data
.prettierrc (12 lines changed)

@@ -2,24 +2,18 @@
  "printWidth": 180,
  "singleQuote": true,
  "htmlWhitespaceSensitivity": "strict",
  "trailingComma": "es5",
  "overrides": [
    {
      "files": [
        "*.html"
      ],
      "files": ["*.html"],
      "options": {
        "parser": "angular"
      }
    },
    {
      "files": [
        "*.{j,t}sx",
        "*.ts"
      ],
      "files": ["*.{j,t}sx", "*.ts"],
      "options": {
        "printWidth": 80
      }
    }
  ]
}
}
.storybook/main.js (new file, 38 lines)

@@ -0,0 +1,38 @@
const TsconfigPathsPlugin = require('tsconfig-paths-webpack-plugin');

module.exports = {
  stories: ['../app/**/*.stories.mdx', '../app/**/*.stories.@(ts|tsx)'],
  addons: [
    '@storybook/addon-links',
    '@storybook/addon-essentials',
    {
      name: '@storybook/addon-postcss',
      options: {
        cssLoaderOptions: {
          importLoaders: 1,
          modules: {
            localIdentName: '[path][name]__[local]',
            auto: true,
            exportLocalsConvention: 'camelCaseOnly',
          },
        },
        postcssLoaderOptions: {
          implementation: require('postcss'),
        },
      },
    },
  ],
  webpackFinal: (config) => {
    config.resolve.plugins = [
      ...(config.resolve.plugins || []),
      new TsconfigPathsPlugin({
        extensions: config.resolve.extensions,
      }),
    ];
    return config;
  },
  core: {
    builder: 'webpack5',
  },
  staticDirs: ['./public'],
};
@@ -1,111 +0,0 @@
import { StorybookConfig } from '@storybook/react-webpack5';

import TsconfigPathsPlugin from 'tsconfig-paths-webpack-plugin';
import { Configuration } from 'webpack';
import postcss from 'postcss';

const config: StorybookConfig = {
  stories: ['../app/**/*.stories.@(ts|tsx)'],
  addons: [
    '@storybook/addon-links',
    '@storybook/addon-essentials',
    '@storybook/addon-webpack5-compiler-swc',
    '@chromatic-com/storybook',
    {
      name: '@storybook/addon-styling-webpack',

      options: {
        rules: [
          {
            test: /\.css$/,
            sideEffects: true,
            use: [
              require.resolve('style-loader'),
              {
                loader: require.resolve('css-loader'),
                options: {
                  importLoaders: 1,
                  modules: {
                    localIdentName: '[path][name]__[local]',
                    auto: true,
                    exportLocalsConvention: 'camelCaseOnly',
                  },
                },
              },
              {
                loader: require.resolve('postcss-loader'),
                options: {
                  implementation: postcss,
                },
              },
            ],
          },
        ],
      },
    },
  ],
  webpackFinal: (config) => {
    const rules = config?.module?.rules || [];

    const imageRule = rules.find((rule) => {
      const test = (rule as { test: RegExp }).test;

      if (!test) {
        return false;
      }

      return test.test('.svg');
    }) as { [key: string]: any };

    imageRule.exclude = /\.svg$/;

    rules.unshift({
      test: /\.svg$/i,
      type: 'asset',
      resourceQuery: {
        not: [/c/],
      }, // exclude react component if *.svg?url
    });

    rules.unshift({
      test: /\.svg$/i,
      issuer: /\.(js|ts)(x)?$/,
      resourceQuery: /c/,
      // *.svg?c
      use: [
        {
          loader: '@svgr/webpack',
          options: {
            icon: true,
          },
        },
      ],
    });
    return {
      ...config,
      resolve: {
        ...config.resolve,
        plugins: [
          ...(config.resolve?.plugins || []),
          new TsconfigPathsPlugin({
            extensions: config.resolve?.extensions,
          }),
        ],
      },
      module: {
        ...config.module,
        rules,
      },
    } satisfies Configuration;
  },
  staticDirs: ['./public'],
  typescript: {
    reactDocgen: 'react-docgen-typescript',
  },
  framework: {
    name: '@storybook/react-webpack5',
    options: {},
  },
};

export default config;
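For context, the SVG rules in the removed Storybook config above mirror how the app's webpack setup distinguishes two kinds of SVG imports: a `?c` resource query is handled by `@svgr/webpack` and yields a React component, while a plain import is handled by the `asset` rule and yields a URL or inlined data URI. A rough sketch of how a story might consume both forms (the file name and component are hypothetical, for illustration only):

```tsx
// Hypothetical story file; './logo.svg' is illustrative only.
// '?c' resource query -> processed by @svgr/webpack, returns a React component
import Logo from './logo.svg?c';

// plain import -> processed by the asset rule, returns a URL (or data URI)
import logoUrl from './logo.svg';

export function LogoExample() {
  return (
    <div>
      <Logo aria-label="logo" />
      <img src={logoUrl} alt="logo" />
    </div>
  );
}
```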
.storybook/preview.js (new file, 48 lines)

@@ -0,0 +1,48 @@
import '../app/assets/css';

import { pushStateLocationPlugin, UIRouter } from '@uirouter/react';
import { initialize as initMSW, mswDecorator } from 'msw-storybook-addon';
import { handlers } from '@/setup-tests/server-handlers';
import { QueryClient, QueryClientProvider } from 'react-query';

// Initialize MSW
initMSW({
  onUnhandledRequest: ({ method, url }) => {
    if (url.pathname.startsWith('/api')) {
      console.error(`Unhandled ${method} request to ${url}.

      This exception has been only logged in the console, however, it's strongly recommended to resolve this error as you don't want unmocked data in Storybook stories.

      If you wish to mock an error response, please refer to this guide: https://mswjs.io/docs/recipes/mocking-error-responses
      `);
    }
  },
});

export const parameters = {
  actions: { argTypesRegex: '^on[A-Z].*' },
  controls: {
    matchers: {
      color: /(background|color)$/i,
      date: /Date$/,
    },
  },
  msw: {
    handlers,
  },
};

const testQueryClient = new QueryClient({
  defaultOptions: { queries: { retry: false } },
});

export const decorators = [
  (Story) => (
    <QueryClientProvider client={testQueryClient}>
      <UIRouter plugins={[pushStateLocationPlugin]}>
        <Story />
      </UIRouter>
    </QueryClientProvider>
  ),
  mswDecorator,
];
@@ -1,50 +0,0 @@
import '../app/assets/css';
import { pushStateLocationPlugin, UIRouter } from '@uirouter/react';
import { initialize as initMSW, mswLoader } from 'msw-storybook-addon';
import { handlers } from '../app/setup-tests/server-handlers';
import { QueryClient, QueryClientProvider } from '@tanstack/react-query';
import { Preview } from '@storybook/react';

initMSW(
  {
    onUnhandledRequest: ({ method, url }) => {
      if (url.startsWith('/api')) {
        console.error(`Unhandled ${method} request to ${url}.

        This exception has been only logged in the console, however, it's strongly recommended to resolve this error as you don't want unmocked data in Storybook stories.

        If you wish to mock an error response, please refer to this guide: https://mswjs.io/docs/recipes/mocking-error-responses
        `);
      }
    },
  },
  handlers
);

const testQueryClient = new QueryClient({
  defaultOptions: { queries: { retry: false } },
});

const preview: Preview = {
  decorators: (Story) => (
    <QueryClientProvider client={testQueryClient}>
      <UIRouter plugins={[pushStateLocationPlugin]}>
        <Story />
      </UIRouter>
    </QueryClientProvider>
  ),
  loaders: [mswLoader],
  parameters: {
    controls: {
      matchers: {
        color: /(background|color)$/i,
        date: /Date$/,
      },
    },
    msw: {
      handlers,
    },
  },
};

export default preview;
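Both preview variants above pull their request mocks from a shared `server-handlers` module and register them through the `msw` parameter. As a rough sketch of what such a handler can look like with the msw 0.x API (matching the 0.36.3 worker shown below); the endpoint and payload here are made up purely for illustration:

```ts
// Hypothetical handler module; '/api/example' and its payload are illustrative only.
import { rest } from 'msw';

export const handlers = [
  // Any story issuing GET /api/example receives this mocked JSON response
  // instead of hitting a real backend.
  rest.get('/api/example', (req, res, ctx) =>
    res(ctx.status(200), ctx.json({ message: 'mocked in Storybook' }))
  ),
];
```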
@@ -2,22 +2,22 @@
/* tslint:disable */

/**
 * Mock Service Worker (2.0.11).
 * Mock Service Worker (0.36.3).
 * @see https://github.com/mswjs/msw
 * - Please do NOT modify this file.
 * - Please do NOT serve this file on production.
 */

const INTEGRITY_CHECKSUM = 'c5f7f8e188b673ea4e677df7ea3c5a39';
const IS_MOCKED_RESPONSE = Symbol('isMockedResponse');
const INTEGRITY_CHECKSUM = '02f4ad4a2797f85668baf196e553d929';
const bypassHeaderName = 'x-msw-bypass';
const activeClientIds = new Set();

self.addEventListener('install', function () {
  self.skipWaiting();
  return self.skipWaiting();
});

self.addEventListener('activate', function (event) {
  event.waitUntil(self.clients.claim());
self.addEventListener('activate', async function (event) {
  return self.clients.claim();
});

self.addEventListener('message', async function (event) {
@@ -33,9 +33,7 @@ self.addEventListener('message', async function (event) {
    return;
  }

  const allClients = await self.clients.matchAll({
    type: 'window',
  });
  const allClients = await self.clients.matchAll();

  switch (event.data) {
    case 'KEEPALIVE_REQUEST': {
@@ -85,8 +83,165 @@
  }
});

// Resolve the "main" client for the given event.
// Client that issues a request doesn't necessarily equal the client
// that registered the worker. It's with the latter the worker should
// communicate with during the response resolving phase.
async function resolveMainClient(event) {
  const client = await self.clients.get(event.clientId);

  if (client.frameType === 'top-level') {
    return client;
  }

  const allClients = await self.clients.matchAll();

  return allClients
    .filter((client) => {
      // Get only those clients that are currently visible.
      return client.visibilityState === 'visible';
    })
    .find((client) => {
      // Find the client ID that's recorded in the
      // set of clients that have registered the worker.
      return activeClientIds.has(client.id);
    });
}

async function handleRequest(event, requestId) {
  const client = await resolveMainClient(event);
  const response = await getResponse(event, client, requestId);

  // Send back the response clone for the "response:*" life-cycle events.
  // Ensure MSW is active and ready to handle the message, otherwise
  // this message will pend indefinitely.
  if (client && activeClientIds.has(client.id)) {
    (async function () {
      const clonedResponse = response.clone();
      sendToClient(client, {
        type: 'RESPONSE',
        payload: {
          requestId,
          type: clonedResponse.type,
          ok: clonedResponse.ok,
          status: clonedResponse.status,
          statusText: clonedResponse.statusText,
          body: clonedResponse.body === null ? null : await clonedResponse.text(),
          headers: serializeHeaders(clonedResponse.headers),
          redirected: clonedResponse.redirected,
        },
      });
    })();
  }

  return response;
}

async function getResponse(event, client, requestId) {
  const { request } = event;
  const requestClone = request.clone();
  const getOriginalResponse = () => fetch(requestClone);

  // Bypass mocking when the request client is not active.
  if (!client) {
    return getOriginalResponse();
|
||||
}
|
||||
|
||||
// Bypass initial page load requests (i.e. static assets).
|
||||
// The absence of the immediate/parent client in the map of the active clients
|
||||
// means that MSW hasn't dispatched the "MOCK_ACTIVATE" event yet
|
||||
// and is not ready to handle requests.
|
||||
if (!activeClientIds.has(client.id)) {
|
||||
return await getOriginalResponse();
|
||||
}
|
||||
|
||||
// Bypass requests with the explicit bypass header
|
||||
if (requestClone.headers.get(bypassHeaderName) === 'true') {
|
||||
const cleanRequestHeaders = serializeHeaders(requestClone.headers);
|
||||
|
||||
// Remove the bypass header to comply with the CORS preflight check.
|
||||
delete cleanRequestHeaders[bypassHeaderName];
|
||||
|
||||
const originalRequest = new Request(requestClone, {
|
||||
headers: new Headers(cleanRequestHeaders),
|
||||
});
|
||||
|
||||
return fetch(originalRequest);
|
||||
}
|
||||
|
||||
// Send the request to the client-side MSW.
|
||||
const reqHeaders = serializeHeaders(request.headers);
|
||||
const body = await request.text();
|
||||
|
||||
const clientMessage = await sendToClient(client, {
|
||||
type: 'REQUEST',
|
||||
payload: {
|
||||
id: requestId,
|
||||
url: request.url,
|
||||
method: request.method,
|
||||
headers: reqHeaders,
|
||||
cache: request.cache,
|
||||
mode: request.mode,
|
||||
credentials: request.credentials,
|
||||
destination: request.destination,
|
||||
integrity: request.integrity,
|
||||
redirect: request.redirect,
|
||||
referrer: request.referrer,
|
||||
referrerPolicy: request.referrerPolicy,
|
||||
body,
|
||||
bodyUsed: request.bodyUsed,
|
||||
keepalive: request.keepalive,
|
||||
},
|
||||
});
|
||||
|
||||
switch (clientMessage.type) {
|
||||
case 'MOCK_SUCCESS': {
|
||||
return delayPromise(() => respondWithMock(clientMessage), clientMessage.payload.delay);
|
||||
}
|
||||
|
||||
case 'MOCK_NOT_FOUND': {
|
||||
return getOriginalResponse();
|
||||
}
|
||||
|
||||
case 'NETWORK_ERROR': {
|
||||
const { name, message } = clientMessage.payload;
|
||||
const networkError = new Error(message);
|
||||
networkError.name = name;
|
||||
|
||||
// Rejecting a request Promise emulates a network error.
|
||||
throw networkError;
|
||||
}
|
||||
|
||||
case 'INTERNAL_ERROR': {
|
||||
const parsedBody = JSON.parse(clientMessage.payload.body);
|
||||
|
||||
console.error(
|
||||
`\
|
||||
[MSW] Uncaught exception in the request handler for "%s %s":
|
||||
|
||||
${parsedBody.location}
|
||||
|
||||
This exception has been gracefully handled as a 500 response, however, it's strongly recommended to resolve this error, as it indicates a mistake in your code. If you wish to mock an error response, please see this guide: https://mswjs.io/docs/recipes/mocking-error-responses\
|
||||
`,
|
||||
request.method,
|
||||
request.url
|
||||
);
|
||||
|
||||
return respondWithMock(clientMessage);
|
||||
}
|
||||
}
|
||||
|
||||
return getOriginalResponse();
|
||||
}
|
||||
|
||||
self.addEventListener('fetch', function (event) {
|
||||
const { request } = event;
|
||||
const accept = request.headers.get('accept') || '';
|
||||
|
||||
// Bypass server-sent events.
|
||||
if (accept.includes('text/event-stream')) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Bypass navigation requests.
|
||||
if (request.mode === 'navigate') {
|
||||
@@ -106,149 +261,36 @@ self.addEventListener('fetch', function (event) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Generate unique request ID.
|
||||
const requestId = crypto.randomUUID();
|
||||
event.respondWith(handleRequest(event, requestId));
|
||||
const requestId = uuidv4();
|
||||
|
||||
return event.respondWith(
|
||||
handleRequest(event, requestId).catch((error) => {
|
||||
if (error.name === 'NetworkError') {
|
||||
console.warn('[MSW] Successfully emulated a network error for the "%s %s" request.', request.method, request.url);
|
||||
return;
|
||||
}
|
||||
|
||||
// At this point, any exception indicates an issue with the original request/response.
|
||||
console.error(
|
||||
`\
|
||||
[MSW] Caught an exception from the "%s %s" request (%s). This is probably not a problem with Mock Service Worker. There is likely an additional logging output above.`,
|
||||
request.method,
|
||||
request.url,
|
||||
`${error.name}: ${error.message}`
|
||||
);
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
async function handleRequest(event, requestId) {
|
||||
const client = await resolveMainClient(event);
|
||||
const response = await getResponse(event, client, requestId);
|
||||
|
||||
// Send back the response clone for the "response:*" life-cycle events.
|
||||
// Ensure MSW is active and ready to handle the message, otherwise
|
||||
// this message will pend indefinitely.
|
||||
if (client && activeClientIds.has(client.id)) {
|
||||
(async function () {
|
||||
const responseClone = response.clone();
|
||||
|
||||
sendToClient(
|
||||
client,
|
||||
{
|
||||
type: 'RESPONSE',
|
||||
payload: {
|
||||
requestId,
|
||||
isMockedResponse: IS_MOCKED_RESPONSE in response,
|
||||
type: responseClone.type,
|
||||
status: responseClone.status,
|
||||
statusText: responseClone.statusText,
|
||||
body: responseClone.body,
|
||||
headers: Object.fromEntries(responseClone.headers.entries()),
|
||||
},
|
||||
},
|
||||
[responseClone.body]
|
||||
);
|
||||
})();
|
||||
}
|
||||
|
||||
return response;
|
||||
}
|
||||
|
||||
// Resolve the main client for the given event.
|
||||
// Client that issues a request doesn't necessarily equal the client
|
||||
// that registered the worker. It's with the latter the worker should
|
||||
// communicate with during the response resolving phase.
|
||||
async function resolveMainClient(event) {
|
||||
const client = await self.clients.get(event.clientId);
|
||||
|
||||
if (client?.frameType === 'top-level') {
|
||||
return client;
|
||||
}
|
||||
|
||||
const allClients = await self.clients.matchAll({
|
||||
type: 'window',
|
||||
function serializeHeaders(headers) {
|
||||
const reqHeaders = {};
|
||||
headers.forEach((value, name) => {
|
||||
reqHeaders[name] = reqHeaders[name] ? [].concat(reqHeaders[name]).concat(value) : value;
|
||||
});
|
||||
|
||||
return allClients
|
||||
.filter((client) => {
|
||||
// Get only those clients that are currently visible.
|
||||
return client.visibilityState === 'visible';
|
||||
})
|
||||
.find((client) => {
|
||||
// Find the client ID that's recorded in the
|
||||
// set of clients that have registered the worker.
|
||||
return activeClientIds.has(client.id);
|
||||
});
|
||||
return reqHeaders;
|
||||
}
|
||||
|
||||
async function getResponse(event, client, requestId) {
|
||||
const { request } = event;
|
||||
|
||||
// Clone the request because it might've been already used
|
||||
// (i.e. its body has been read and sent to the client).
|
||||
const requestClone = request.clone();
|
||||
|
||||
function passthrough() {
|
||||
const headers = Object.fromEntries(requestClone.headers.entries());
|
||||
|
||||
// Remove internal MSW request header so the passthrough request
|
||||
// complies with any potential CORS preflight checks on the server.
|
||||
// Some servers forbid unknown request headers.
|
||||
delete headers['x-msw-intention'];
|
||||
|
||||
return fetch(requestClone, { headers });
|
||||
}
|
||||
|
||||
// Bypass mocking when the client is not active.
|
||||
if (!client) {
|
||||
return passthrough();
|
||||
}
|
||||
|
||||
// Bypass initial page load requests (i.e. static assets).
|
||||
// The absence of the immediate/parent client in the map of the active clients
|
||||
// means that MSW hasn't dispatched the "MOCK_ACTIVATE" event yet
|
||||
// and is not ready to handle requests.
|
||||
if (!activeClientIds.has(client.id)) {
|
||||
return passthrough();
|
||||
}
|
||||
|
||||
// Bypass requests with the explicit bypass header.
|
||||
// Such requests can be issued by "ctx.fetch()".
|
||||
const mswIntention = request.headers.get('x-msw-intention');
|
||||
if (['bypass', 'passthrough'].includes(mswIntention)) {
|
||||
return passthrough();
|
||||
}
|
||||
|
||||
// Notify the client that a request has been intercepted.
|
||||
const requestBuffer = await request.arrayBuffer();
|
||||
const clientMessage = await sendToClient(
|
||||
client,
|
||||
{
|
||||
type: 'REQUEST',
|
||||
payload: {
|
||||
id: requestId,
|
||||
url: request.url,
|
||||
mode: request.mode,
|
||||
method: request.method,
|
||||
headers: Object.fromEntries(request.headers.entries()),
|
||||
cache: request.cache,
|
||||
credentials: request.credentials,
|
||||
destination: request.destination,
|
||||
integrity: request.integrity,
|
||||
redirect: request.redirect,
|
||||
referrer: request.referrer,
|
||||
referrerPolicy: request.referrerPolicy,
|
||||
body: requestBuffer,
|
||||
keepalive: request.keepalive,
|
||||
},
|
||||
},
|
||||
[requestBuffer]
|
||||
);
|
||||
|
||||
switch (clientMessage.type) {
|
||||
case 'MOCK_RESPONSE': {
|
||||
return respondWithMock(clientMessage.data);
|
||||
}
|
||||
|
||||
case 'MOCK_NOT_FOUND': {
|
||||
return passthrough();
|
||||
}
|
||||
}
|
||||
|
||||
return passthrough();
|
||||
}
|
||||
|
||||
function sendToClient(client, message, transferrables = []) {
|
||||
function sendToClient(client, message) {
|
||||
return new Promise((resolve, reject) => {
|
||||
const channel = new MessageChannel();
|
||||
|
||||
@@ -260,25 +302,27 @@ function sendToClient(client, message, transferrables = []) {
|
||||
resolve(event.data);
|
||||
};
|
||||
|
||||
client.postMessage(message, [channel.port2].concat(transferrables.filter(Boolean)));
|
||||
client.postMessage(JSON.stringify(message), [channel.port2]);
|
||||
});
|
||||
}
|
||||
|
||||
async function respondWithMock(response) {
|
||||
// Setting response status code to 0 is a no-op.
|
||||
// However, when responding with a "Response.error()", the produced Response
|
||||
// instance will have status code set to 0. Since it's not possible to create
|
||||
// a Response instance with status code 0, handle that use-case separately.
|
||||
if (response.status === 0) {
|
||||
return Response.error();
|
||||
}
|
||||
|
||||
const mockedResponse = new Response(response.body, response);
|
||||
|
||||
Reflect.defineProperty(mockedResponse, IS_MOCKED_RESPONSE, {
|
||||
value: true,
|
||||
enumerable: true,
|
||||
function delayPromise(cb, duration) {
|
||||
return new Promise((resolve) => {
|
||||
setTimeout(() => resolve(cb()), duration);
|
||||
});
|
||||
}
|
||||
|
||||
function respondWithMock(clientMessage) {
|
||||
return new Response(clientMessage.payload.body, {
|
||||
...clientMessage.payload,
|
||||
headers: clientMessage.payload.headers,
|
||||
});
|
||||
}
|
||||
|
||||
function uuidv4() {
|
||||
return 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'.replace(/[xy]/g, function (c) {
|
||||
const r = (Math.random() * 16) | 0;
|
||||
const v = c == 'x' ? r : (r & 0x3) | 0x8;
|
||||
return v.toString(16);
|
||||
});
|
||||
|
||||
return mockedResponse;
|
||||
}
|
||||
|
||||
19  .vscode.example/launch.json  Normal file
@@ -0,0 +1,19 @@
{
  // Use IntelliSense to learn about possible attributes.
  // Hover to view descriptions of existing attributes.
  // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
  "version": "0.2.0",
  "configurations": [
    {
      "name": "Launch",
      "type": "go",
      "request": "launch",
      "mode": "debug",
      "program": "${workspaceRoot}/api/cmd/portainer",
      "cwd": "${workspaceRoot}",
      "env": {},
      "showLog": true,
      "args": ["--data", "${env:HOME}/portainer-data", "--assets", "${workspaceRoot}/dist"]
    }
  ]
}
182  .vscode.example/portainer.code-snippets  Normal file
@@ -0,0 +1,182 @@
{
|
||||
// Place your portainer workspace snippets here. Each snippet is defined under a snippet name and has a scope, prefix, body and
|
||||
// description. Add comma separated ids of the languages where the snippet is applicable in the scope field. If scope
|
||||
// is left empty or omitted, the snippet gets applied to all languages. The prefix is what is
|
||||
// used to trigger the snippet and the body will be expanded and inserted. Possible variables are:
|
||||
// $1, $2 for tab stops, $0 for the final cursor position, and ${1:label}, ${2:another} for placeholders.
|
||||
// Placeholders with the same ids are connected.
|
||||
// Example:
|
||||
// "Print to console": {
|
||||
// "scope": "javascript,typescript",
|
||||
// "prefix": "log",
|
||||
// "body": [
|
||||
// "console.log('$1');",
|
||||
// "$2"
|
||||
// ],
|
||||
// "description": "Log output to console"
|
||||
// }
|
||||
"Component": {
|
||||
"scope": "javascript",
|
||||
"prefix": "mycomponent",
|
||||
"description": "Dummy Angularjs Component",
|
||||
"body": [
|
||||
"import angular from 'angular';",
|
||||
"import controller from './${TM_FILENAME_BASE}Controller'",
|
||||
"",
|
||||
"angular.module('portainer.${TM_DIRECTORY/.*\\/app\\/([^\\/]*)(\\/.*)?$/$1/}').component('$TM_FILENAME_BASE', {",
|
||||
" templateUrl: './$TM_FILENAME_BASE.html',",
|
||||
" controller,",
|
||||
"});",
|
||||
""
|
||||
]
|
||||
},
|
||||
"Controller": {
|
||||
"scope": "javascript",
|
||||
"prefix": "mycontroller",
|
||||
"body": [
|
||||
"class ${TM_FILENAME_BASE/(.*)/${1:/capitalize}/} {",
|
||||
"\t/* @ngInject */",
|
||||
"\tconstructor($0) {",
|
||||
"\t}",
|
||||
"}",
|
||||
"",
|
||||
"export default ${TM_FILENAME_BASE/(.*)/${1:/capitalize}/};"
|
||||
],
|
||||
"description": "Dummy ES6+ controller"
|
||||
},
|
||||
"Service": {
|
||||
"scope": "javascript",
|
||||
"prefix": "myservice",
|
||||
"description": "Dummy ES6+ service",
|
||||
"body": [
|
||||
"import angular from 'angular';",
|
||||
"import PortainerError from 'Portainer/error';",
|
||||
"",
|
||||
"class $1 {",
|
||||
" /* @ngInject */",
|
||||
" constructor(\\$async, $0) {",
|
||||
" this.\\$async = \\$async;",
|
||||
"",
|
||||
" this.getAsync = this.getAsync.bind(this);",
|
||||
" this.getAllAsync = this.getAllAsync.bind(this);",
|
||||
" this.createAsync = this.createAsync.bind(this);",
|
||||
" this.updateAsync = this.updateAsync.bind(this);",
|
||||
" this.deleteAsync = this.deleteAsync.bind(this);",
|
||||
" }",
|
||||
"",
|
||||
" /**",
|
||||
" * GET",
|
||||
" */",
|
||||
" async getAsync() {",
|
||||
" try {",
|
||||
"",
|
||||
" } catch (err) {",
|
||||
" throw new PortainerError('', err);",
|
||||
" }",
|
||||
" }",
|
||||
"",
|
||||
" async getAllAsync() {",
|
||||
" try {",
|
||||
"",
|
||||
" } catch (err) {",
|
||||
" throw new PortainerError('', err);",
|
||||
" }",
|
||||
" }",
|
||||
"",
|
||||
" get() {",
|
||||
" if () {",
|
||||
" return this.\\$async(this.getAsync);",
|
||||
" }",
|
||||
" return this.\\$async(this.getAllAsync);",
|
||||
" }",
|
||||
"",
|
||||
" /**",
|
||||
" * CREATE",
|
||||
" */",
|
||||
" async createAsync() {",
|
||||
" try {",
|
||||
"",
|
||||
" } catch (err) {",
|
||||
" throw new PortainerError('', err);",
|
||||
" }",
|
||||
" }",
|
||||
"",
|
||||
" create() {",
|
||||
" return this.\\$async(this.createAsync);",
|
||||
" }",
|
||||
"",
|
||||
" /**",
|
||||
" * UPDATE",
|
||||
" */",
|
||||
" async updateAsync() {",
|
||||
" try {",
|
||||
"",
|
||||
" } catch (err) {",
|
||||
" throw new PortainerError('', err);",
|
||||
" }",
|
||||
" }",
|
||||
"",
|
||||
" update() {",
|
||||
" return this.\\$async(this.updateAsync);",
|
||||
" }",
|
||||
"",
|
||||
" /**",
|
||||
" * DELETE",
|
||||
" */",
|
||||
" async deleteAsync() {",
|
||||
" try {",
|
||||
"",
|
||||
" } catch (err) {",
|
||||
" throw new PortainerError('', err);",
|
||||
" }",
|
||||
" }",
|
||||
"",
|
||||
" delete() {",
|
||||
" return this.\\$async(this.deleteAsync);",
|
||||
" }",
|
||||
"}",
|
||||
"",
|
||||
"export default $1;",
|
||||
"angular.module('portainer.${TM_DIRECTORY/.*\\/app\\/([^\\/]*)(\\/.*)?$/$1/}').service('$1', $1);"
|
||||
]
|
||||
},
|
||||
"swagger-api-doc": {
|
||||
"prefix": "swapi",
|
||||
"scope": "go",
|
||||
"description": "Snippet for a api doc",
|
||||
"body": [
|
||||
"// @id ",
|
||||
"// @summary ",
|
||||
"// @description ",
|
||||
"// @description **Access policy**: ",
|
||||
"// @tags ",
|
||||
"// @security ApiKeyAuth",
|
||||
"// @security jwt",
|
||||
"// @accept json",
|
||||
"// @produce json",
|
||||
"// @param id path int true \"identifier\"",
|
||||
"// @param body body Object true \"details\"",
|
||||
"// @success 200 {object} portainer. \"Success\"",
|
||||
"// @success 204 \"Success\"",
|
||||
"// @failure 400 \"Invalid request\"",
|
||||
"// @failure 403 \"Permission denied\"",
|
||||
"// @failure 404 \" not found\"",
|
||||
"// @failure 500 \"Server error\"",
|
||||
"// @router /{id} [get]"
|
||||
]
|
||||
},
|
||||
"analytics": {
|
||||
"prefix": "nlt",
|
||||
"body": ["analytics-on", "analytics-category=\"$1\"", "analytics-event=\"$2\""],
|
||||
"description": "analytics"
|
||||
},
|
||||
"analytics-if": {
|
||||
"prefix": "nltf",
|
||||
"body": ["analytics-if=\"$1\""],
|
||||
"description": "analytics"
|
||||
},
|
||||
"analytics-metadata": {
|
||||
"prefix": "nltm",
|
||||
"body": "analytics-properties=\"{ metadata: { $1 } }\""
|
||||
}
|
||||
}
|
||||
8  .vscode.example/settings.json  Normal file
@@ -0,0 +1,8 @@
{
  "go.lintTool": "golangci-lint",
  "go.lintFlags": ["--fast", "-E", "exportloopref"],
  "gopls": {
    "build.expandWorkspaceToModule": false
  },
  "gitlens.advanced.blame.customArguments": ["--ignore-revs-file", ".git-blame-ignore-revs"]
}
44  CLAUDE.md
@@ -1,44 +0,0 @@
# Portainer Community Edition

Open-source container management platform with full Docker and Kubernetes support.

see also:

- docs/guidelines/server-architecture.md
- docs/guidelines/go-conventions.md
- docs/guidelines/typescript-conventions.md

## Package Manager

- **PNPM** 10+ (for frontend)
- **Go** 1.25.7 (for backend)

## Build Commands

```bash
# Full build
make build          # Build both client and server
make build-client   # Build React/AngularJS frontend
make build-server   # Build Go binary
make build-image    # Build Docker image

# Development
make dev            # Run both in dev mode
make dev-client     # Start webpack-dev-server (port 8999)
make dev-server     # Run containerized Go server

pnpm run dev        # Webpack dev server
pnpm run build      # Build frontend with webpack
pnpm run test       # Run frontend tests

# Testing
make test           # All tests (backend + frontend)
make test-server    # Backend tests only
make lint           # Lint all code
make format         # Format code
```

## Development Servers

- Frontend: http://localhost:8999
- Backend: http://localhost:9000 (HTTP) / https://localhost:9443 (HTTPS)
@@ -25,7 +25,7 @@ Each commit message should include a **type**, a **scope** and a **subject**:
<type>(<scope>): <subject>
```

Lines should not exceed 100 characters. This allows the message to be easier to read on GitHub as well as in various git tools and produces a nice, neat commit log ie:
Lines should not exceed 100 characters. This allows the message to be easier to read on github as well as in various git tools and produces a nice, neat commit log ie:

```
#271 feat(containers): add exposed ports in the containers view
@@ -63,7 +63,7 @@ The subject contains succinct description of the change:

## Contribution process

Our contribution process is described below. Some of the steps can be visualized inside GitHub via specific `status/` labels, such as `status/1-functional-review` or `status/2-technical-review`.
Our contribution process is described below. Some of the steps can be visualized inside Github via specific `status/` labels, such as `status/1-functional-review` or `status/2-technical-review`.

### Bug report

@@ -77,44 +77,32 @@ The feature request process is similar to the bug report process but has an extr

## Build and run Portainer locally

Ensure you have Docker, Node.js, pnpm, and Golang installed in the correct versions.
Ensure you have Docker, Node.js, yarn, and Golang installed in the correct versions.

Install dependencies:
Install dependencies with yarn:

```sh
$ make deps
$ yarn
```

Then build and run the project in a Docker container:

```sh
$ make dev
$ yarn start
```

Portainer server can now be accessed at <https://localhost:9443>. and UI dev server runs on <http://localhost:8999>.
Portainer can now be accessed at <https://localhost:9443>.

if you want to build the project you can run:
Find more detailed steps at <https://documentation.portainer.io/contributing/instructions/>.

```sh
make build-all
```
### Build customisation

For additional make commands, run `make help`.

Find more detailed steps at <https://docs.portainer.io/contribute/build>.

### Build customization

You can customize the following settings:
You can customise the following settings:

- `PORTAINER_DATA`: The host dir or volume name used by portainer (default is `/tmp/portainer`, which won't persist over reboots).
- `PORTAINER_PROJECT`: The root dir of the repository - `${portainerRoot}/dist/` is imported into the container to get the build artifacts and external tools (defaults to `your current dir`).
- `PORTAINER_FLAGS`: a list of flags to be used on the portainer commandline, in the form `--admin-password=<pwd hash> --feat fdo=false --feat open-amt` (default: `""`).

## Testing your build

The `--log-level=DEBUG` flag can be passed to the Portainer container in order to provide additional debug output which may be useful when troubleshooting your builds. Please note that this flag was originally intended for internal use and as such the format, functionality and output may change between releases without warning.

## Adding api docs

When adding a new resource (or a route handler), we should add a new tag to api/http/handler/handler.go#L136 like this:
118  Makefile
@@ -1,118 +0,0 @@
# build target, can be one of "production", "testing", "development"
ENV=development
WEBPACK_CONFIG=webpack/webpack.$(ENV).js
TAG=local

SWAG=go run github.com/swaggo/swag/cmd/swag@v1.16.2
GOTESTSUM=go run gotest.tools/gotestsum@latest

# Don't change anything below this line unless you know what you're doing
.DEFAULT_GOAL := help

##@ Building
.PHONY: all init-dist build-storybook build build-client build-server build-image devops
init-dist:
	@mkdir -p dist

all: tidy deps build-server build-client ## Build the client, server and download external dependancies (doesn't build an image)

build-all: all ## Alias for the 'all' target (used by CI)

build-client: init-dist ## Build the client
	export NODE_ENV=$(ENV) && pnpm run build --config $(WEBPACK_CONFIG)

build-server: init-dist ## Build the server binary
	./build/build_binary.sh "$(PLATFORM)" "$(ARCH)"

build-image: build-all ## Build the Portainer image locally
	docker buildx build --load -t portainerci/portainer-ce:$(TAG) -f build/linux/Dockerfile .

build-storybook: ## Build and serve the storybook files
	pnpm run storybook:build

##@ Build dependencies
.PHONY: deps server-deps client-deps tidy
deps: server-deps client-deps ## Download all client and server build dependancies

server-deps: init-dist ## Download dependant server binaries
	@./build/download_binaries.sh $(PLATFORM) $(ARCH)

client-deps: ## Install client dependencies
	pnpm install

tidy: ## Tidy up the go.mod file
	@go mod tidy

##@ Cleanup
.PHONY: clean
clean: ## Remove all build and download artifacts
	@echo "Clearing the dist directory..."
	@rm -rf dist/*

##@ Testing
.PHONY: test test-client test-server
test: test-server test-client ## Run all tests

test-client: ## Run client tests
	pnpm run test $(ARGS) --coverage

test-server: ## Run server tests
	$(GOTESTSUM) --format pkgname-and-test-fails --format-hide-empty-pkg --hide-summary skipped -- -cover -covermode=atomic -coverprofile=coverage.out ./...

##@ Dev
.PHONY: dev dev-client dev-server
dev: ## Run both the client and server in development mode
	make dev-server
	make dev-client

dev-client: ## Run the client in development mode
	pnpm install && pnpm run dev

dev-server: build-server ## Run the server in development mode
	@./dev/run_container.sh

dev-server-podman: build-server ## Run the server in development mode
	@./dev/run_container_podman.sh

##@ Format
.PHONY: format format-client format-server

format: format-client format-server ## Format all code

format-client: ## Format client code
	pnpm run format

format-server: ## Format server code
	go fmt ./...

##@ Lint
.PHONY: lint lint-client lint-server
lint: lint-client lint-server ## Lint all code

lint-client: ## Lint client code
	pnpm run lint

lint-server: tidy ## Lint server code
	golangci-lint run --timeout=10m -c .golangci.yaml
	golangci-lint run --timeout=10m --new-from-rev=HEAD~ -c .golangci-forward.yaml

##@ Extension
.PHONY: dev-extension
dev-extension: build-server build-client ## Run the extension in development mode
	make local -f build/docker-extension/Makefile

##@ Docs
.PHONY: docs-build docs-validate docs-clean docs-validate-clean
docs-build: init-dist ## Build docs
	go mod download -x
	cd api && $(SWAG) init -o "../dist/docs" -ot "yaml" -g ./http/handler/handler.go --parseDependency --parseInternal --parseDepth 2 -p pascalcase --markdownFiles ./

docs-validate: docs-build ## Validate docs
	pnpm swagger2openapi --warnOnly dist/docs/swagger.yaml -o dist/docs/openapi.yaml
	pnpm swagger-cli validate dist/docs/openapi.yaml

##@ Helpers
.PHONY: help
help: ## Display this help
	@awk 'BEGIN {FS = ":.*##"; printf "Usage:\n make \033[36m<target>\033[0m\n"} /^[a-zA-Z_-]+:.*?##/ { printf "  \033[36m%-15s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST)
38  README.md
@@ -8,55 +8,65 @@ Portainer consists of a single container that can run on any cluster. It can be

**Portainer Business Edition** builds on the open-source base and includes a range of advanced features and functions (like RBAC and Support) that are specific to the needs of business users.

- [Compare Portainer CE and Compare Portainer BE](https://www.portainer.io/features)
- [Take3 – get 3 free nodes of Portainer Business for as long as you want them](https://www.portainer.io/take-3)
- [Portainer BE install guide](https://academy.portainer.io/install/)
- [Compare Portainer CE and Compare Portainer BE](https://portainer.io/products)
- [Take5 – get 5 free nodes of Portainer Business for as long as you want them](https://portainer.io/pricing/take5)
- [Portainer BE install guide](https://install.portainer.io)

## Demo

You can try out the public demo instance: http://demo.portainer.io/ (login with the username **admin** and the password **tryportainer**).

Please note that the public demo cluster is **reset every 15min**.

## Latest Version

Portainer CE is updated regularly. We aim to do an update release every couple of months.

[](https://github.com/portainer/portainer/releases/latest)
**The latest version of Portainer is 2.9.x**. Portainer is on version 2, the second number denotes the month of release.

## Getting started

- [Deploy Portainer](https://docs.portainer.io/start/install-ce)
- [Documentation](https://docs.portainer.io)
- [Contribute to the project](https://docs.portainer.io/contribute/contribute)
- [Deploy Portainer](https://docs.portainer.io/v/ce-2.9/start/install)
- [Documentation](https://documentation.portainer.io)
- [Contribute to the project](https://documentation.portainer.io/contributing/instructions/)

## Features & Functions

View [this](https://www.portainer.io/features) table to see all of the Portainer CE functionality and compare to Portainer Business.
View [this](https://www.portainer.io/products) table to see all of the Portainer CE functionality and compare to Portainer Business.

- [Portainer CE for Docker / Docker Swarm](https://www.portainer.io/solutions/docker)
- [Portainer CE for Kubernetes](https://www.portainer.io/solutions/kubernetes-ui)
- [Portainer CE for Azure ACI](https://www.portainer.io/solutions/serverless-containers)

## Getting help

Portainer CE is an open source project and is supported by the community. You can buy a supported version of Portainer at portainer.io

Learn more about Portainer's community support channels [here.](https://www.portainer.io/resources/get-help/get-support)
Learn more about Portainer's community support channels [here.](https://www.portainer.io/community_help)

- Issues: https://github.com/portainer/portainer/issues
- Slack (chat): [https://portainer.io/slack](https://portainer.io/slack)

You can join the Portainer Community by visiting [https://www.portainer.io/join-our-community](https://www.portainer.io/join-our-community). This will give you advance notice of events, content and other related Portainer content.
You can join the Portainer Community by visiting community.portainer.io. This will give you advance notice of events, content and other related Portainer content.

## Reporting bugs and contributing

- Want to report a bug or request a feature? Please open [an issue](https://github.com/portainer/portainer/issues/new).
- Want to help us build **_portainer_**? Follow our [contribution guidelines](https://docs.portainer.io/contribute/contribute) to build it locally and make a pull request.
- Want to help us build **_portainer_**? Follow our [contribution guidelines](https://documentation.portainer.io/contributing/instructions/) to build it locally and make a pull request.

## Security

For information about reporting security vulnerabilities, please see our [Security Policy](SECURITY.md).
- Here at Portainer, we believe in [responsible disclosure](https://en.wikipedia.org/wiki/Responsible_disclosure) of security issues. If you have found a security issue, please report it to <security@portainer.io>.

## Work for us

If you are a developer, and our code in this repo makes sense to you, we would love to hear from you. We are always on the hunt for awesome devs, either freelance or employed. Drop us a line to success@portainer.io with your details and/or visit our [careers page](https://apply.workable.com/portainer/).
If you are a developer, and our code in this repo makes sense to you, we would love to hear from you. We are always on the hunt for awesome devs, either freelance or employed. Drop us a line to info@portainer.io with your details and/or visit our [careers page](https://portainer.io/careers).

## Privacy

**To make sure we focus our development effort in the right places we need to know which features get used most often. To give us this information we use [Matomo Analytics](https://matomo.org/), which is hosted in Germany and is fully GDPR compliant.**

When Portainer first starts, you are given the option to DISABLE analytics. If you **don't** choose to disable it, we collect anonymous usage as per [our privacy policy](https://www.portainer.io/legal/privacy-policy). **Please note**, there is no personally identifiable information sent or stored at any time and we only use the data to help us improve Portainer.
When Portainer first starts, you are given the option to DISABLE analytics. If you **don't** choose to disable it, we collect anonymous usage as per [our privacy policy](https://www.portainer.io/documentation/in-app-analytics-and-privacy-policy/). **Please note**, there is no personally identifiable information sent or stored at any time and we only use the data to help us improve Portainer.

## Limitations
61  SECURITY.md
@@ -1,61 +0,0 @@
# Security Policy

## Supported Versions

Portainer maintains both Short-Term Support (STS) and Long-Term Support (LTS) versions in accordance with our official [Portainer Lifecycle Policy](https://docs.portainer.io/start/lifecycle).

| Version Type | Support Status |
| --- | --- |
| LTS (Long-Term Support) | Supported for critical security fixes |
| STS (Short-Term Support) | Supported until the next STS or LTS release |
| Legacy / EOL | Not supported |

For a detailed breakdown of current versions and their specific End of Life (EOL) dates,
please refer to the [Portainer Lifecycle Policy](https://docs.portainer.io/start/lifecycle).

## Reporting a Vulnerability

The Portainer team takes the security of our products seriously. If you believe you have found a security vulnerability in any Portainer-owned repository, please report it to us responsibly.

**Please do not report security vulnerabilities via public GitHub issues.**

### Disclosure Process

1. **Report**: You can report in one of two ways:

   - **GitHub**: Use the **Report a vulnerability** button on the **Security** tab of this repository.

   - **Email**: Send your findings to security@portainer.io.

2. **Details**: To help us verify the issue, please include:

   - A description of the vulnerability and its potential impact.

   - Step-by-step instructions to reproduce the issue (e.g. proof-of-concept code, scripts, or screenshots).

   - The version of the software and the environment in which it was found.

3. **Acknowledge**: We will acknowledge receipt of your report and provide an initial assessment.

4. **Resolution**: We will work to resolve the issue as quickly as possible. We request that you do not disclose the vulnerability publicly until we have released a fix and notified affected users.

## Our Commitment

If you follow the responsible disclosure process, we will:

- Respond to your report in a timely manner.

- Provide an estimated timeline for remediation.

- Notify you when the vulnerability has been patched.

- Give credit for the discovery (if desired) once the fix is public.

We will make every effort to promptly address any security weaknesses. Security advisories and fixes will be published through GitHub Security Advisories and other channels as needed.

Thank you for helping keep Portainer and our community secure.

## Resources

- [Contributing to Portainer](https://docs.portainer.io/contribute/contribute#contributing-to-the-portainer-ce-codebase)
@@ -2,18 +2,19 @@ package adminmonitor

import (
	"context"
	"log"
	"net/http"
	"strings"
	"sync"
	"time"

	httperror "github.com/portainer/libhttp/error"
	portainer "github.com/portainer/portainer/api"
	"github.com/portainer/portainer/api/dataservices"
	httperror "github.com/portainer/portainer/pkg/libhttp/error"

	"github.com/rs/zerolog/log"
)

var logFatalf = log.Fatalf

const RedirectReasonAdminInitTimeout string = "AdminInitTimeout"

type Monitor struct {
@@ -21,11 +22,11 @@ type Monitor struct {
	datastore         dataservices.DataStore
	shutdownCtx       context.Context
	cancellationFunc  context.CancelFunc
	mu                sync.RWMutex
	mu                sync.Mutex
	adminInitDisabled bool
}

// New creates a monitor that when started will wait for the timeout duration and then shutdown the application unless it has been initialized.
// New creates a monitor that when started will wait for the timeout duration and then sends the timeout signal to disable the application
func New(timeout time.Duration, datastore dataservices.DataStore, shutdownCtx context.Context) *Monitor {
	return &Monitor{
		timeout: timeout,
@@ -48,29 +49,24 @@ func (m *Monitor) Start() {
	m.cancellationFunc = cancellationFunc

	go func() {
		log.Debug().Msg("start initialization monitor")

		log.Println("[DEBUG] [internal,init] [message: start initialization monitor ]")
		select {
		case <-time.After(m.timeout):
			initialized, err := m.WasInitialized()
			if err != nil {
				log.Error().Err(err).Msg("AdminMonitor failed to determine if Portainer is Initialized")
				return
				logFatalf("%s", err)
			}

			if !initialized {
				log.Info().Msg("the Portainer instance timed out for security purposes, to re-enable your Portainer instance, you will need to restart Portainer")

				log.Println("[INFO] [internal,init] The Portainer instance timed out for security purposes. To re-enable your Portainer instance, you will need to restart Portainer")
				m.mu.Lock()
				defer m.mu.Unlock()

				m.adminInitDisabled = true
				return
			}
		case <-cancellationCtx.Done():
			log.Debug().Msg("canceling initialization monitor")
			log.Println("[DEBUG] [internal,init] [message: canceling initialization monitor]")
		case <-m.shutdownCtx.Done():
			log.Debug().Msg("shutting down initialization monitor")
			log.Println("[DEBUG] [internal,init] [message: shutting down initialization monitor]")
		}
	}()
}
@@ -83,7 +79,6 @@ func (m *Monitor) Stop() {
	if m.cancellationFunc == nil {
		return
	}

	m.cancellationFunc()
	m.cancellationFunc = nil
}
@@ -94,14 +89,12 @@ func (m *Monitor) WasInitialized() (bool, error) {
	if err != nil {
		return false, err
	}

	return len(users) > 0, nil
}

func (m *Monitor) WasInstanceDisabled() bool {
	m.mu.RLock()
	defer m.mu.RUnlock()

	m.mu.Lock()
	defer m.mu.Unlock()
	return m.adminInitDisabled
}

@@ -109,10 +102,12 @@ func (m *Monitor) WasInstanceDisabled() bool {
// Otherwise, it will pass through the request to next
func (m *Monitor) WithRedirect(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if m.WasInstanceDisabled() && strings.HasPrefix(r.RequestURI, "/api") && r.RequestURI != "/api/status" && r.RequestURI != "/api/settings/public" {
			w.Header().Set("redirect-reason", RedirectReasonAdminInitTimeout)
			httperror.WriteError(w, http.StatusSeeOther, "Administrator initialization timeout", nil)
			return
		if m.WasInstanceDisabled() {
			if strings.HasPrefix(r.RequestURI, "/api") && r.RequestURI != "/api/status" && r.RequestURI != "/api/settings/public" {
				w.Header().Set("redirect-reason", RedirectReasonAdminInitTimeout)
				httperror.WriteError(w, http.StatusSeeOther, "Administrator initialization timeout", nil)
				return
			}
		}

		next.ServeHTTP(w, r)
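Aside (not part of the diff above): the `Start` goroutine in this hunk races a timeout against two cancellation signals with a `select`. Below is a minimal, self-contained sketch of that pattern; the names `waitOrCancel` and `onTimeout` are illustrative and do not exist in the Portainer code.

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// waitOrCancel blocks until either the timeout elapses or one of the two
// contexts is done, mirroring the select used by Monitor.Start above.
func waitOrCancel(timeout time.Duration, cancelCtx, shutdownCtx context.Context, onTimeout func()) {
	select {
	case <-time.After(timeout):
		onTimeout()
	case <-cancelCtx.Done():
		fmt.Println("monitor canceled")
	case <-shutdownCtx.Done():
		fmt.Println("shutting down")
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	waitOrCancel(10*time.Millisecond, ctx, context.Background(), func() {
		fmt.Println("timed out: the instance would be disabled at this point")
	})
}
```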
@@ -1,74 +0,0 @@
package agent

import (
	"crypto/tls"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strconv"
	"time"

	portainer "github.com/portainer/portainer/api"
	"github.com/portainer/portainer/api/url"

	"github.com/rs/zerolog/log"
)

// GetAgentVersionAndPlatform returns the agent version and platform
//
// it sends a ping to the agent and parses the version and platform from the headers
func GetAgentVersionAndPlatform(endpointUrl string, tlsConfig *tls.Config) (portainer.AgentPlatform, string, error) { //nolint:forbidigo
	httpCli := &http.Client{Timeout: 3 * time.Second}

	if tlsConfig != nil {
		httpCli.Transport = &http.Transport{TLSClientConfig: tlsConfig}
	}

	parsedURL, err := url.ParseURL(endpointUrl + "/ping")
	if err != nil {
		return 0, "", err
	}

	parsedURL.Scheme = "https"

	req, err := http.NewRequest(http.MethodGet, parsedURL.String(), nil)
	if err != nil {
		return 0, "", err
	}

	resp, err := httpCli.Do(req)
	if err != nil {
		return 0, "", err
	}

	_, _ = io.Copy(io.Discard, resp.Body)
	if err := resp.Body.Close(); err != nil {
		log.Warn().Err(err).Msg("failed to close response body")
	}

	if resp.StatusCode != http.StatusNoContent {
		return 0, "", fmt.Errorf("Failed request with status %d", resp.StatusCode)
	}

	version := resp.Header.Get(portainer.PortainerAgentHeader)
	if version == "" {
		return 0, "", errors.New("Version Header is missing")
	}

	agentPlatformHeader := resp.Header.Get(portainer.HTTPResponseAgentPlatform)
	if agentPlatformHeader == "" {
		return 0, "", errors.New("Agent Platform Header is missing")
	}

	agentPlatformNumber, err := strconv.Atoi(agentPlatformHeader)
	if err != nil {
		return 0, "", err
	}

	if agentPlatformNumber == 0 {
		return 0, "", errors.New("Agent platform is invalid")
	}

	return portainer.AgentPlatform(agentPlatformNumber), version, nil
}
@@ -1,17 +1,30 @@
package apikey

import (
	"crypto/rand"
	"io"

	portainer "github.com/portainer/portainer/api"
)

// APIKeyService represents a service for managing API keys.
type APIKeyService interface {
	HashRaw(rawKey string) string
	HashRaw(rawKey string) []byte
	GenerateApiKey(user portainer.User, description string) (string, *portainer.APIKey, error)
	GetAPIKey(apiKeyID portainer.APIKeyID) (*portainer.APIKey, error)
	GetAPIKeys(userID portainer.UserID) ([]portainer.APIKey, error)
	GetDigestUserAndKey(digest string) (portainer.User, portainer.APIKey, error)
	GetDigestUserAndKey(digest []byte) (portainer.User, portainer.APIKey, error)
	UpdateAPIKey(apiKey *portainer.APIKey) error
	DeleteAPIKey(apiKeyID portainer.APIKeyID) error
	InvalidateUserKeyCache(userId portainer.UserID) bool
}

// generateRandomKey generates a random key of specified length
// source: https://github.com/gorilla/securecookie/blob/master/securecookie.go#L515
func generateRandomKey(length int) []byte {
	k := make([]byte, length)
	if _, err := io.ReadFull(rand.Reader, k); err != nil {
		return nil
	}
	return k
}
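Aside (not part of the diff above): `HashRaw` and `GetDigestUserAndKey` imply that only a digest of the raw API key is ever stored and compared. The exact hashing and encoding Portainer uses is not visible in this compare; the sketch below assumes SHA-256 with base64 encoding purely for illustration (the service file does import `crypto/sha256` and `encoding/base64`, and the `ptr_` prefix appears in the diff).

```go
package main

import (
	"crypto/sha256"
	"encoding/base64"
	"fmt"
)

// hashRaw digests a raw API key so only the digest needs to be stored.
// SHA-256 + base64 is an assumption for illustration, not the confirmed
// Portainer implementation.
func hashRaw(rawKey string) []byte {
	sum := sha256.Sum256([]byte(rawKey))
	return []byte(base64.StdEncoding.EncodeToString(sum[:]))
}

func main() {
	digest := hashRaw("ptr_example-raw-key") // "ptr_" prefix as shown in the diff
	fmt.Printf("digest: %s\n", digest)
}
```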
@@ -10,42 +10,40 @@ func Test_generateRandomKey(t *testing.T) {
	is := assert.New(t)

	tests := []struct {
		name       string
		wantLength int
		name      string
		wantLenth int
	}{
		{
			name:       "Generate a random key of length 16",
			wantLength: 16,
			name:      "Generate a random key of length 16",
			wantLenth: 16,
		},
		{
			name:       "Generate a random key of length 32",
			wantLength: 32,
			name:      "Generate a random key of length 32",
			wantLenth: 32,
		},
		{
			name:       "Generate a random key of length 64",
			wantLength: 64,
			name:      "Generate a random key of length 64",
			wantLenth: 64,
		},
		{
			name:       "Generate a random key of length 128",
			wantLength: 128,
			name:      "Generate a random key of length 128",
			wantLenth: 128,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := GenerateRandomKey(tt.wantLength)
			is.Len(got, tt.wantLength)
			got := generateRandomKey(tt.wantLenth)
			is.Equal(tt.wantLenth, len(got))
		})
	}

	t.Run("Generated keys are unique", func(t *testing.T) {
		keys := make(map[string]bool)

		for range 100 {
			key := GenerateRandomKey(8)
		for i := 0; i < 100; i++ {
			key := generateRandomKey(8)
			_, ok := keys[string(key)]
			is.False(ok)

			keys[string(key)] = true
		}
	})
@@ -1,79 +1,69 @@
package apikey

import (
	portainer "github.com/portainer/portainer/api"

	lru "github.com/hashicorp/golang-lru"
	portainer "github.com/portainer/portainer/api"
)

const DefaultAPIKeyCacheSize = 1024
const defaultAPIKeyCacheSize = 1024

// entry is a tuple containing the user and API key associated to an API key digest
type entry[T any] struct {
	user T
type entry struct {
	user   portainer.User
	apiKey portainer.APIKey
}

type UserCompareFn[T any] func(T, portainer.UserID) bool

// ApiKeyCache is a concurrency-safe, in-memory cache which primarily exists for to reduce database roundtrips.
// apiKeyCache is a concurrency-safe, in-memory cache which primarily exists for to reduce database roundtrips.
// We store the api-key digest (keys) and the associated user and key-data (values) in the cache.
// This is required because HTTP requests will contain only the api-key digest in the x-api-key request header;
// digest value must be mapped to a portainer user (and respective key data) for validation.
// This cache is used to avoid multiple database queries to retrieve these user/key associated to the digest.
type ApiKeyCache[T any] struct {
type apiKeyCache struct {
	// cache type [string]entry cache (key: string(digest), value: user/key entry)
	// note: []byte keys are not supported by golang-lru Cache
	cache     *lru.Cache
	userCmpFn UserCompareFn[T]
	cache *lru.Cache
}

// NewAPIKeyCache creates a new cache for API keys
func NewAPIKeyCache[T any](cacheSize int, userCompareFn UserCompareFn[T]) *ApiKeyCache[T] {
func NewAPIKeyCache(cacheSize int) *apiKeyCache {
	cache, _ := lru.New(cacheSize)

	return &ApiKeyCache[T]{cache: cache, userCmpFn: userCompareFn}
	return &apiKeyCache{cache: cache}
}

// Get returns the user/key associated to an api-key's digest
// This is required because HTTP requests will contain the digest of the API key in header,
// the digest value must be mapped to a portainer user.
func (c *ApiKeyCache[T]) Get(digest string) (T, portainer.APIKey, bool) {
	val, ok := c.cache.Get(digest)
func (c *apiKeyCache) Get(digest []byte) (portainer.User, portainer.APIKey, bool) {
	val, ok := c.cache.Get(string(digest))
	if !ok {
		var t T

		return t, portainer.APIKey{}, false
		return portainer.User{}, portainer.APIKey{}, false
	}

	tuple := val.(entry[T])
	tuple := val.(entry)

	return tuple.user, tuple.apiKey, true
}

// Set persists a user/key entry to the cache
func (c *ApiKeyCache[T]) Set(digest string, user T, apiKey portainer.APIKey) {
	c.cache.Add(digest, entry[T]{
func (c *apiKeyCache) Set(digest []byte, user portainer.User, apiKey portainer.APIKey) {
	c.cache.Add(string(digest), entry{
		user:   user,
		apiKey: apiKey,
	})
}

// Delete evicts a digest's user/key entry key from the cache
func (c *ApiKeyCache[T]) Delete(digest string) {
	c.cache.Remove(digest)
func (c *apiKeyCache) Delete(digest []byte) {
	c.cache.Remove(string(digest))
}

// InvalidateUserKeyCache loops through all the api-keys associated to a user and removes them from the cache
func (c *ApiKeyCache[T]) InvalidateUserKeyCache(userId portainer.UserID) bool {
func (c *apiKeyCache) InvalidateUserKeyCache(userId portainer.UserID) bool {
	present := false

	for _, k := range c.cache.Keys() {
		user, _, _ := c.Get(k.(string))
		if c.userCmpFn(user, userId) {
		user, _, _ := c.Get([]byte(k.(string)))
		if user.ID == userId {
			present = c.cache.Remove(k)
		}
	}

	return present
}
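Aside (not part of the compared commits): the doc comment above describes a cache-aside lookup, where the digest is checked in the LRU cache first and the database is only queried on a miss. A minimal, self-contained sketch of that pattern follows; `lookupUser` and `loadFromDB` are hypothetical names used for illustration, not Portainer functions.

```go
package main

import (
	"fmt"

	lru "github.com/hashicorp/golang-lru"
)

// lookupUser consults the LRU cache keyed by the api-key digest first and
// only falls back to loadFromDB on a miss, which is the database roundtrip
// the cache is meant to avoid.
func lookupUser(cache *lru.Cache, digest string, loadFromDB func(string) (string, error)) (string, error) {
	if v, ok := cache.Get(digest); ok {
		return v.(string), nil
	}

	user, err := loadFromDB(digest)
	if err != nil {
		return "", err
	}

	cache.Add(digest, user)

	return user, nil
}

func main() {
	cache, _ := lru.New(1024)

	fromDB := func(digest string) (string, error) { return "admin", nil }

	u, _ := lookupUser(cache, "digest-123", fromDB)
	fmt.Println(u) // a second call with the same digest would be served from the cache
}
```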
@@ -10,32 +10,32 @@ import (
func Test_apiKeyCacheGet(t *testing.T) {
	is := assert.New(t)

	keyCache := NewAPIKeyCache(10, compareUser)
	keyCache := NewAPIKeyCache(10)

	// pre-populate cache
	keyCache.cache.Add(string("foo"), entry[portainer.User]{user: portainer.User{}, apiKey: portainer.APIKey{}})
	keyCache.cache.Add(string(""), entry[portainer.User]{user: portainer.User{}, apiKey: portainer.APIKey{}})
	keyCache.cache.Add(string("foo"), entry{user: portainer.User{}, apiKey: portainer.APIKey{}})
	keyCache.cache.Add(string(""), entry{user: portainer.User{}, apiKey: portainer.APIKey{}})

	tests := []struct {
		digest string
		digest []byte
		found  bool
	}{
		{
			digest: "foo",
			digest: []byte("foo"),
			found:  true,
		},
		{
			digest: "",
			digest: []byte(""),
			found:  true,
		},
		{
			digest: "bar",
			digest: []byte("bar"),
			found:  false,
		},
	}

	for _, test := range tests {
		t.Run(test.digest, func(t *testing.T) {
		t.Run(string(test.digest), func(t *testing.T) {
			_, _, found := keyCache.Get(test.digest)
			is.Equal(test.found, found)
		})
@@ -45,43 +45,43 @@ func Test_apiKeyCacheGet(t *testing.T) {
func Test_apiKeyCacheSet(t *testing.T) {
	is := assert.New(t)

	keyCache := NewAPIKeyCache(10, compareUser)
	keyCache := NewAPIKeyCache(10)

	// pre-populate cache
	keyCache.Set("bar", portainer.User{ID: 2}, portainer.APIKey{})
	keyCache.Set("foo", portainer.User{ID: 1}, portainer.APIKey{})
	keyCache.Set([]byte("bar"), portainer.User{ID: 2}, portainer.APIKey{})
	keyCache.Set([]byte("foo"), portainer.User{ID: 1}, portainer.APIKey{})

	// overwrite existing entry
	keyCache.Set("foo", portainer.User{ID: 3}, portainer.APIKey{})
	keyCache.Set([]byte("foo"), portainer.User{ID: 3}, portainer.APIKey{})

	val, ok := keyCache.cache.Get(string("bar"))
	is.True(ok)

	tuple := val.(entry[portainer.User])
	tuple := val.(entry)
	is.Equal(portainer.User{ID: 2}, tuple.user)

	val, ok = keyCache.cache.Get(string("foo"))
	is.True(ok)

	tuple = val.(entry[portainer.User])
	tuple = val.(entry)
	is.Equal(portainer.User{ID: 3}, tuple.user)
}

func Test_apiKeyCacheDelete(t *testing.T) {
	is := assert.New(t)

	keyCache := NewAPIKeyCache(10, compareUser)
	keyCache := NewAPIKeyCache(10)

	t.Run("Delete an existing entry", func(t *testing.T) {
		keyCache.cache.Add(string("foo"), entry[portainer.User]{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}})
		keyCache.Delete("foo")
		keyCache.cache.Add(string("foo"), entry{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}})
		keyCache.Delete([]byte("foo"))

		_, ok := keyCache.cache.Get(string("foo"))
		is.False(ok)
	})

	t.Run("Delete a non-existing entry", func(t *testing.T) {
		nonPanicFunc := func() { keyCache.Delete("non-existent-key") }
		nonPanicFunc := func() { keyCache.Delete([]byte("non-existent-key")) }
		is.NotPanics(nonPanicFunc)
	})
}
@@ -128,19 +128,19 @@ func Test_apiKeyCacheLRU(t *testing.T) {

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			keyCache := NewAPIKeyCache(test.cacheLen, compareUser)
			keyCache := NewAPIKeyCache(test.cacheLen)

			for _, key := range test.key {
				keyCache.Set(key, portainer.User{ID: 1}, portainer.APIKey{})
				keyCache.Set([]byte(key), portainer.User{ID: 1}, portainer.APIKey{})
			}

			for _, key := range test.foundKeys {
				_, _, found := keyCache.Get(key)
				_, _, found := keyCache.Get([]byte(key))
				is.True(found, "Key %s not found", key)
			}

			for _, key := range test.evictedKeys {
				_, _, found := keyCache.Get(key)
				_, _, found := keyCache.Get([]byte(key))
				is.False(found, "key %s should have been evicted", key)
			}
		})
@@ -150,10 +150,10 @@ func Test_apiKeyCacheLRU(t *testing.T) {
func Test_apiKeyCacheInvalidateUserKeyCache(t *testing.T) {
	is := assert.New(t)

	keyCache := NewAPIKeyCache(10, compareUser)
	keyCache := NewAPIKeyCache(10)

	t.Run("Removes users keys from cache", func(t *testing.T) {
		keyCache.cache.Add(string("foo"), entry[portainer.User]{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}})
		keyCache.cache.Add(string("foo"), entry{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}})

		ok := keyCache.InvalidateUserKeyCache(1)
		is.True(ok)
@@ -163,8 +163,8 @@ func Test_apiKeyCacheInvalidateUserKeyCache(t *testing.T) {
	})

	t.Run("Does not affect other keys", func(t *testing.T) {
		keyCache.cache.Add(string("foo"), entry[portainer.User]{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}})
		keyCache.cache.Add(string("bar"), entry[portainer.User]{user: portainer.User{ID: 2}, apiKey: portainer.APIKey{}})
		keyCache.cache.Add(string("foo"), entry{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}})
		keyCache.cache.Add(string("bar"), entry{user: portainer.User{ID: 2}, apiKey: portainer.APIKey{}})

		ok := keyCache.InvalidateUserKeyCache(1)
		is.True(ok)
@@ -1,17 +1,15 @@
|
||||
package apikey
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const portainerAPIKeyPrefix = "ptr_"
|
||||
@@ -21,45 +19,30 @@ var ErrInvalidAPIKey = errors.New("Invalid API key")
|
||||
type apiKeyService struct {
|
||||
apiKeyRepository dataservices.APIKeyRepository
|
||||
userRepository dataservices.UserService
|
||||
cache *ApiKeyCache[portainer.User]
|
||||
}
|
||||
|
||||
// GenerateRandomKey generates a random key of specified length
|
||||
// source: https://github.com/gorilla/securecookie/blob/master/securecookie.go#L515
|
||||
func GenerateRandomKey(length int) []byte {
|
||||
k := make([]byte, length)
|
||||
if _, err := io.ReadFull(rand.Reader, k); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return k
|
||||
}
|
||||
|
||||
func compareUser(u portainer.User, id portainer.UserID) bool {
|
||||
return u.ID == id
|
||||
cache *apiKeyCache
|
||||
}
|
||||
|
||||
func NewAPIKeyService(apiKeyRepository dataservices.APIKeyRepository, userRepository dataservices.UserService) *apiKeyService {
|
||||
return &apiKeyService{
|
||||
apiKeyRepository: apiKeyRepository,
|
||||
userRepository: userRepository,
|
||||
cache: NewAPIKeyCache(DefaultAPIKeyCacheSize, compareUser),
|
||||
cache: NewAPIKeyCache(defaultAPIKeyCacheSize),
|
||||
}
|
||||
}
|
||||
|
||||
// HashRaw computes a hash digest of provided raw API key.
|
||||
func (a *apiKeyService) HashRaw(rawKey string) string {
|
||||
func (a *apiKeyService) HashRaw(rawKey string) []byte {
|
||||
hashDigest := sha256.Sum256([]byte(rawKey))
|
||||
|
||||
return base64.StdEncoding.EncodeToString(hashDigest[:])
|
||||
return hashDigest[:]
|
||||
}
|
||||
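For reference, both variants of HashRaw above reduce to a plain SHA-256 over the prefixed raw key; they differ only in whether the digest is kept as raw bytes or base64-encoded. A minimal standalone sketch of that computation (the main wrapper and names are illustrative, not part of the package):

package main

import (
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"fmt"
	"io"
)

func main() {
	// 32 random bytes, mirroring GenerateRandomKey(32)
	k := make([]byte, 32)
	if _, err := io.ReadFull(rand.Reader, k); err != nil {
		panic(err)
	}

	// prefixed raw key, shown to the user once
	rawKey := "ptr_" + base64.StdEncoding.EncodeToString(k)

	// digest stored in the database and used as the cache key
	digest := sha256.Sum256([]byte(rawKey))

	fmt.Println(rawKey)
	fmt.Println(digest[:])                                    // raw []byte variant
	fmt.Println(base64.StdEncoding.EncodeToString(digest[:])) // base64 string variant
}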
|
||||
// GenerateApiKey generates a raw API key for a user (for one-time display).
|
||||
// The generated API key is stored in the cache and database.
|
||||
func (a *apiKeyService) GenerateApiKey(user portainer.User, description string) (string, *portainer.APIKey, error) {
|
||||
randKey := GenerateRandomKey(32)
|
||||
randKey := generateRandomKey(32)
|
||||
encodedRawAPIKey := base64.StdEncoding.EncodeToString(randKey)
|
||||
prefixedAPIKey := portainerAPIKeyPrefix + encodedRawAPIKey
|
||||
|
||||
hashDigest := a.HashRaw(prefixedAPIKey)
|
||||
|
||||
apiKey := &portainer.APIKey{
|
||||
@@ -70,7 +53,8 @@ func (a *apiKeyService) GenerateApiKey(user portainer.User, description string)
|
||||
Digest: hashDigest,
|
||||
}
|
||||
|
||||
if err := a.apiKeyRepository.Create(apiKey); err != nil {
|
||||
err := a.apiKeyRepository.CreateAPIKey(apiKey)
|
||||
if err != nil {
|
||||
return "", nil, errors.Wrap(err, "Unable to create API key")
|
||||
}
|
||||
|
||||
@@ -82,7 +66,7 @@ func (a *apiKeyService) GenerateApiKey(user portainer.User, description string)
|
||||
|
||||
// GetAPIKey returns an API key by its ID.
|
||||
func (a *apiKeyService) GetAPIKey(apiKeyID portainer.APIKeyID) (*portainer.APIKey, error) {
|
||||
return a.apiKeyRepository.Read(apiKeyID)
|
||||
return a.apiKeyRepository.GetAPIKey(apiKeyID)
|
||||
}
|
||||
|
||||
// GetAPIKeys returns all the API keys associated to a user.
|
||||
@@ -92,7 +76,8 @@ func (a *apiKeyService) GetAPIKeys(userID portainer.UserID) ([]portainer.APIKey,
|
||||
|
||||
// GetDigestUserAndKey returns the user and api-key associated to a specified hash digest.
|
||||
// A cache lookup is performed first; if the user/api-key is not found in the cache, respective database lookups are performed.
|
||||
func (a *apiKeyService) GetDigestUserAndKey(digest string) (portainer.User, portainer.APIKey, error) {
|
||||
func (a *apiKeyService) GetDigestUserAndKey(digest []byte) (portainer.User, portainer.APIKey, error) {
|
||||
// get api key from cache if possible
|
||||
cachedUser, cachedKey, ok := a.cache.Get(digest)
|
||||
if ok {
|
||||
return cachedUser, cachedKey, nil
|
||||
@@ -103,7 +88,7 @@ func (a *apiKeyService) GetDigestUserAndKey(digest string) (portainer.User, port
|
||||
return portainer.User{}, portainer.APIKey{}, errors.Wrap(err, "Unable to retrieve API key")
|
||||
}
|
||||
|
||||
user, err := a.userRepository.Read(apiKey.UserID)
|
||||
user, err := a.userRepository.User(apiKey.UserID)
|
||||
if err != nil {
|
||||
return portainer.User{}, portainer.APIKey{}, errors.Wrap(err, "Unable to retrieve digest user")
|
||||
}
|
||||
@@ -120,22 +105,21 @@ func (a *apiKeyService) UpdateAPIKey(apiKey *portainer.APIKey) error {
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Unable to retrieve API key")
|
||||
}
|
||||
|
||||
a.cache.Set(apiKey.Digest, user, *apiKey)
|
||||
|
||||
return a.apiKeyRepository.Update(apiKey.ID, apiKey)
|
||||
return a.apiKeyRepository.UpdateAPIKey(apiKey)
|
||||
}
|
||||
|
||||
// DeleteAPIKey deletes an API key and removes the digest/api-key entry from the cache.
|
||||
func (a *apiKeyService) DeleteAPIKey(apiKeyID portainer.APIKeyID) error {
|
||||
apiKey, err := a.apiKeyRepository.Read(apiKeyID)
|
||||
// get api-key digest to remove from cache
|
||||
apiKey, err := a.apiKeyRepository.GetAPIKey(apiKeyID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, fmt.Sprintf("Unable to retrieve API key: %d", apiKeyID))
|
||||
}
|
||||
|
||||
// delete the user/api-key from cache
|
||||
a.cache.Delete(apiKey.Digest)
|
||||
|
||||
return a.apiKeyRepository.Delete(apiKeyID)
|
||||
return a.apiKeyRepository.DeleteAPIKey(apiKeyID)
|
||||
}
|
||||
|
||||
func (a *apiKeyService) InvalidateUserKeyCache(userId portainer.UserID) bool {
|
||||
|
||||
@@ -2,18 +2,14 @@ package apikey
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_SatisfiesAPIKeyServiceInterface(t *testing.T) {
|
||||
@@ -24,14 +20,15 @@ func Test_SatisfiesAPIKeyServiceInterface(t *testing.T) {
|
||||
func Test_GenerateApiKey(t *testing.T) {
|
||||
is := assert.New(t)
|
||||
|
||||
_, store := datastore.MustNewTestStore(t, true, true)
|
||||
_, store, teardown := datastore.MustNewTestStore(true, true)
|
||||
defer teardown()
|
||||
|
||||
service := NewAPIKeyService(store.APIKeyRepository(), store.User())
|
||||
|
||||
t.Run("Successfully generates API key", func(t *testing.T) {
|
||||
desc := "test-1"
|
||||
rawKey, apiKey, err := service.GenerateApiKey(portainer.User{ID: 1}, desc)
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
is.NotEmpty(rawKey)
|
||||
is.NotEmpty(apiKey)
|
||||
is.Equal(desc, apiKey.Description)
|
||||
@@ -39,7 +36,7 @@ func Test_GenerateApiKey(t *testing.T) {
|
||||
|
||||
t.Run("Api key prefix is 7 chars", func(t *testing.T) {
|
||||
rawKey, apiKey, err := service.GenerateApiKey(portainer.User{ID: 1}, "test-2")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
is.Equal(rawKey[:7], apiKey.Prefix)
|
||||
is.Len(apiKey.Prefix, 7)
|
||||
@@ -47,7 +44,7 @@ func Test_GenerateApiKey(t *testing.T) {
|
||||
|
||||
t.Run("Api key has 'ptr_' as prefix", func(t *testing.T) {
|
||||
rawKey, _, err := service.GenerateApiKey(portainer.User{ID: 1}, "test-x")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
is.Equal(portainerAPIKeyPrefix, "ptr_")
|
||||
is.True(strings.HasPrefix(rawKey, "ptr_"))
|
||||
@@ -56,7 +53,7 @@ func Test_GenerateApiKey(t *testing.T) {
|
||||
t.Run("Successfully caches API key", func(t *testing.T) {
|
||||
user := portainer.User{ID: 1}
|
||||
_, apiKey, err := service.GenerateApiKey(user, "test-3")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
userFromCache, apiKeyFromCache, ok := service.cache.Get(apiKey.Digest)
|
||||
is.True(ok)
|
||||
@@ -66,28 +63,29 @@ func Test_GenerateApiKey(t *testing.T) {
|
||||
|
||||
t.Run("Decoded raw api-key digest matches generated digest", func(t *testing.T) {
|
||||
rawKey, apiKey, err := service.GenerateApiKey(portainer.User{ID: 1}, "test-4")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
generatedDigest := sha256.Sum256([]byte(rawKey))
|
||||
|
||||
is.Equal(apiKey.Digest, base64.StdEncoding.EncodeToString(generatedDigest[:]))
|
||||
is.Equal(apiKey.Digest, generatedDigest[:])
|
||||
})
|
||||
}
|
||||
|
||||
func Test_GetAPIKey(t *testing.T) {
|
||||
is := assert.New(t)
|
||||
|
||||
_, store := datastore.MustNewTestStore(t, true, true)
|
||||
_, store, teardown := datastore.MustNewTestStore(true, true)
|
||||
defer teardown()
|
||||
|
||||
service := NewAPIKeyService(store.APIKeyRepository(), store.User())
|
||||
|
||||
t.Run("Successfully returns all API keys", func(t *testing.T) {
|
||||
user := portainer.User{ID: 1}
|
||||
_, apiKey, err := service.GenerateApiKey(user, "test-1")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
apiKeyGot, err := service.GetAPIKey(apiKey.ID)
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
is.Equal(apiKey, apiKeyGot)
|
||||
})
|
||||
@@ -96,19 +94,20 @@ func Test_GetAPIKey(t *testing.T) {
|
||||
func Test_GetAPIKeys(t *testing.T) {
|
||||
is := assert.New(t)
|
||||
|
||||
_, store := datastore.MustNewTestStore(t, true, true)
|
||||
_, store, teardown := datastore.MustNewTestStore(true, true)
|
||||
defer teardown()
|
||||
|
||||
service := NewAPIKeyService(store.APIKeyRepository(), store.User())
|
||||
|
||||
t.Run("Successfully returns all API keys", func(t *testing.T) {
|
||||
user := portainer.User{ID: 1}
|
||||
_, _, err := service.GenerateApiKey(user, "test-1")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
_, _, err = service.GenerateApiKey(user, "test-2")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
keys, err := service.GetAPIKeys(user.ID)
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
is.Len(keys, 2)
|
||||
})
|
||||
}
|
||||
@@ -116,17 +115,18 @@ func Test_GetAPIKeys(t *testing.T) {
|
||||
func Test_GetDigestUserAndKey(t *testing.T) {
|
||||
is := assert.New(t)
|
||||
|
||||
_, store := datastore.MustNewTestStore(t, true, true)
|
||||
_, store, teardown := datastore.MustNewTestStore(true, true)
|
||||
defer teardown()
|
||||
|
||||
service := NewAPIKeyService(store.APIKeyRepository(), store.User())
|
||||
|
||||
t.Run("Successfully returns user and api key associated to digest", func(t *testing.T) {
|
||||
user := portainer.User{ID: 1}
|
||||
_, apiKey, err := service.GenerateApiKey(user, "test-1")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
userGot, apiKeyGot, err := service.GetDigestUserAndKey(apiKey.Digest)
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
is.Equal(user, userGot)
|
||||
is.Equal(*apiKey, apiKeyGot)
|
||||
})
|
||||
@@ -134,10 +134,10 @@ func Test_GetDigestUserAndKey(t *testing.T) {
|
||||
t.Run("Successfully caches user and api key associated to digest", func(t *testing.T) {
|
||||
user := portainer.User{ID: 1}
|
||||
_, apiKey, err := service.GenerateApiKey(user, "test-1")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
userGot, apiKeyGot, err := service.GetDigestUserAndKey(apiKey.Digest)
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
is.Equal(user, userGot)
|
||||
is.Equal(*apiKey, apiKeyGot)
|
||||
|
||||
@@ -151,34 +151,34 @@ func Test_GetDigestUserAndKey(t *testing.T) {
|
||||
func Test_UpdateAPIKey(t *testing.T) {
|
||||
is := assert.New(t)
|
||||
|
||||
_, store := datastore.MustNewTestStore(t, true, true)
|
||||
_, store, teardown := datastore.MustNewTestStore(true, true)
|
||||
defer teardown()
|
||||
|
||||
service := NewAPIKeyService(store.APIKeyRepository(), store.User())
|
||||
|
||||
t.Run("Successfully updates the api-key LastUsed time", func(t *testing.T) {
|
||||
user := portainer.User{ID: 1}
|
||||
|
||||
err := store.User().Create(&user)
|
||||
require.NoError(t, err)
|
||||
|
||||
store.User().Create(&user)
|
||||
_, apiKey, err := service.GenerateApiKey(user, "test-x")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
apiKey.LastUsed = time.Now().UTC().Unix()
|
||||
err = service.UpdateAPIKey(apiKey)
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
_, apiKeyGot, err := service.GetDigestUserAndKey(apiKey.Digest)
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
log.Debug().Str("wanted", fmt.Sprintf("%+v", apiKey)).Str("got", fmt.Sprintf("%+v", apiKeyGot)).Msg("")
|
||||
log.Println(apiKey)
|
||||
log.Println(apiKeyGot)
|
||||
|
||||
is.Equal(apiKey.LastUsed, apiKeyGot.LastUsed)
|
||||
|
||||
})
|
||||
|
||||
t.Run("Successfully updates api-key in cache upon api-key update", func(t *testing.T) {
|
||||
_, apiKey, err := service.GenerateApiKey(portainer.User{ID: 1}, "test-x2")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
_, apiKeyFromCache, ok := service.cache.Get(apiKey.Digest)
|
||||
is.True(ok)
|
||||
@@ -188,7 +188,7 @@ func Test_UpdateAPIKey(t *testing.T) {
|
||||
is.NotEqual(*apiKey, apiKeyFromCache)
|
||||
|
||||
err = service.UpdateAPIKey(apiKey)
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
_, updatedAPIKeyFromCache, ok := service.cache.Get(apiKey.Digest)
|
||||
is.True(ok)
|
||||
@@ -199,37 +199,38 @@ func Test_UpdateAPIKey(t *testing.T) {
|
||||
func Test_DeleteAPIKey(t *testing.T) {
|
||||
is := assert.New(t)
|
||||
|
||||
_, store := datastore.MustNewTestStore(t, true, true)
|
||||
_, store, teardown := datastore.MustNewTestStore(true, true)
|
||||
defer teardown()
|
||||
|
||||
service := NewAPIKeyService(store.APIKeyRepository(), store.User())
|
||||
|
||||
t.Run("Successfully updates the api-key", func(t *testing.T) {
|
||||
user := portainer.User{ID: 1}
|
||||
_, apiKey, err := service.GenerateApiKey(user, "test-1")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
_, apiKeyGot, err := service.GetDigestUserAndKey(apiKey.Digest)
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
is.Equal(*apiKey, apiKeyGot)
|
||||
|
||||
err = service.DeleteAPIKey(apiKey.ID)
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
_, _, err = service.GetDigestUserAndKey(apiKey.Digest)
|
||||
require.Error(t, err)
|
||||
is.Error(err)
|
||||
})
|
||||
|
||||
t.Run("Successfully removes api-key from cache upon deletion", func(t *testing.T) {
|
||||
user := portainer.User{ID: 1}
|
||||
_, apiKey, err := service.GenerateApiKey(user, "test-1")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
_, apiKeyFromCache, ok := service.cache.Get(apiKey.Digest)
|
||||
is.True(ok)
|
||||
is.Equal(*apiKey, apiKeyFromCache)
|
||||
|
||||
err = service.DeleteAPIKey(apiKey.ID)
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
_, _, ok = service.cache.Get(apiKey.Digest)
|
||||
is.False(ok)
|
||||
@@ -239,7 +240,8 @@ func Test_DeleteAPIKey(t *testing.T) {
|
||||
func Test_InvalidateUserKeyCache(t *testing.T) {
|
||||
is := assert.New(t)
|
||||
|
||||
_, store := datastore.MustNewTestStore(t, true, true)
|
||||
_, store, teardown := datastore.MustNewTestStore(true, true)
|
||||
defer teardown()
|
||||
|
||||
service := NewAPIKeyService(store.APIKeyRepository(), store.User())
|
||||
|
||||
@@ -247,10 +249,10 @@ func Test_InvalidateUserKeyCache(t *testing.T) {
|
||||
// generate api keys
|
||||
user := portainer.User{ID: 1}
|
||||
_, apiKey1, err := service.GenerateApiKey(user, "test-1")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
_, apiKey2, err := service.GenerateApiKey(user, "test-2")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
// verify api keys are present in cache
|
||||
_, apiKeyFromCache, ok := service.cache.Get(apiKey1.Digest)
|
||||
@@ -277,11 +279,11 @@ func Test_InvalidateUserKeyCache(t *testing.T) {
|
||||
// generate keys for 2 users
|
||||
user1 := portainer.User{ID: 1}
|
||||
_, apiKey1, err := service.GenerateApiKey(user1, "test-1")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
user2 := portainer.User{ID: 2}
|
||||
_, apiKey2, err := service.GenerateApiKey(user2, "test-2")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
// verify keys in cache
|
||||
_, apiKeyFromCache, ok := service.cache.Get(apiKey1.Digest)
|
||||
|
||||
@@ -17,54 +17,20 @@ func TarFileInBuffer(fileContent []byte, fileName string, mode int64) ([]byte, e
|
||||
Size: int64(len(fileContent)),
|
||||
}
|
||||
|
||||
if err := tarWriter.WriteHeader(header); err != nil {
|
||||
err := tarWriter.WriteHeader(header)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := tarWriter.Write(fileContent); err != nil {
|
||||
_, err = tarWriter.Write(fileContent)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := tarWriter.Close(); err != nil {
|
||||
err = tarWriter.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return buffer.Bytes(), nil
|
||||
}
|
||||
|
||||
// tarFileInBuffer represents a tar archive buffer.
|
||||
type tarFileInBuffer struct {
|
||||
b *bytes.Buffer
|
||||
w *tar.Writer
|
||||
}
|
||||
|
||||
func NewTarFileInBuffer() *tarFileInBuffer {
|
||||
var b bytes.Buffer
|
||||
return &tarFileInBuffer{b: &b, w: tar.NewWriter(&b)}
|
||||
}
|
||||
|
||||
// Put puts a single file to tar archive buffer.
|
||||
func (t *tarFileInBuffer) Put(fileContent []byte, fileName string, mode int64) error {
|
||||
hdr := &tar.Header{
|
||||
Name: fileName,
|
||||
Mode: mode,
|
||||
Size: int64(len(fileContent)),
|
||||
}
|
||||
|
||||
if err := t.w.WriteHeader(hdr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := t.w.Write(fileContent)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Bytes returns the archive as a byte array.
|
||||
func (t *tarFileInBuffer) Bytes() []byte {
|
||||
return t.b.Bytes()
|
||||
}
|
||||
|
||||
func (t *tarFileInBuffer) Close() error {
|
||||
return t.w.Close()
|
||||
}
|
||||
|
||||
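A brief usage sketch of the buffer-backed writer above, assuming the Put/Close/Bytes signatures shown in this file; buildSampleArchive is a hypothetical helper, not part of the package:

package archive

// buildSampleArchive is a hypothetical example of driving tarFileInBuffer.
func buildSampleArchive() ([]byte, error) {
	t := NewTarFileInBuffer()

	// add a single file entry to the in-memory archive
	if err := t.Put([]byte("hello"), "hello.txt", 0600); err != nil {
		return nil, err
	}

	// Close flushes the tar footer; Bytes then returns the finished archive
	if err := t.Close(); err != nil {
		return nil, err
	}

	return t.Bytes(), nil
}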
@@ -3,33 +3,28 @@ package archive
|
||||
import (
|
||||
"archive/tar"
|
||||
"compress/gzip"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/portainer/portainer/api/filesystem"
|
||||
"github.com/portainer/portainer/api/logs"
|
||||
)
|
||||
|
||||
// TarGzDir creates a tar.gz archive and returns its path.
|
||||
// absolutePath should be an absolute path to a directory.
|
||||
// Archive name will be <directoryName>.tar.gz and will be placed next to the directory.
|
||||
func TarGzDir(absolutePath string) (string, error) {
|
||||
targzPath := filepath.Join(absolutePath, filepath.Base(absolutePath)+".tar.gz")
|
||||
targzPath := filepath.Join(absolutePath, fmt.Sprintf("%s.tar.gz", filepath.Base(absolutePath)))
|
||||
outFile, err := os.Create(targzPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer logs.CloseAndLogErr(outFile)
|
||||
defer outFile.Close()
|
||||
|
||||
zipWriter := gzip.NewWriter(outFile)
|
||||
defer logs.CloseAndLogErr(zipWriter)
|
||||
|
||||
defer zipWriter.Close()
|
||||
tarWriter := tar.NewWriter(zipWriter)
|
||||
defer logs.CloseAndLogErr(tarWriter)
|
||||
defer tarWriter.Close()
|
||||
|
||||
err = filepath.Walk(absolutePath, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
@@ -52,6 +47,18 @@ func TarGzDir(absolutePath string) (string, error) {
|
||||
}
|
||||
|
||||
func addToArchive(tarWriter *tar.Writer, pathInArchive string, path string, info os.FileInfo) error {
|
||||
header, err := tar.FileInfoHeader(info, info.Name())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
header.Name = pathInArchive // use relative paths in archive
|
||||
|
||||
err = tarWriter.WriteHeader(header)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
@@ -60,26 +67,6 @@ func addToArchive(tarWriter *tar.Writer, pathInArchive string, path string, info
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stat, err := file.Stat()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
header, err := tar.FileInfoHeader(stat, stat.Name())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
header.Name = pathInArchive // use relative paths in archive
|
||||
|
||||
err = tarWriter.WriteHeader(header)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if stat.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err = io.Copy(tarWriter, file)
|
||||
return err
|
||||
}
|
||||
@@ -90,14 +77,14 @@ func ExtractTarGz(r io.Reader, outputDirPath string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer logs.CloseAndLogErr(zipReader)
|
||||
defer zipReader.Close()
|
||||
|
||||
tarReader := tar.NewReader(zipReader)
|
||||
|
||||
for {
|
||||
header, err := tarReader.Next()
|
||||
|
||||
if errors.Is(err, io.EOF) {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
|
||||
@@ -109,8 +96,8 @@ func ExtractTarGz(r io.Reader, outputDirPath string) error {
|
||||
case tar.TypeDir:
|
||||
// skip, dir will be created with a file
|
||||
case tar.TypeReg:
|
||||
p := filesystem.JoinPaths(outputDirPath, header.Name)
|
||||
if err := os.MkdirAll(filepath.Dir(p), 0o744); err != nil {
|
||||
p := filepath.Clean(filepath.Join(outputDirPath, header.Name))
|
||||
if err := os.MkdirAll(filepath.Dir(p), 0744); err != nil {
|
||||
return fmt.Errorf("Failed to extract dir %s", filepath.Dir(p))
|
||||
}
|
||||
outFile, err := os.Create(p)
|
||||
@@ -120,9 +107,9 @@ func ExtractTarGz(r io.Reader, outputDirPath string) error {
|
||||
if _, err := io.Copy(outFile, tarReader); err != nil {
|
||||
return fmt.Errorf("Failed to extract file %s", header.Name)
|
||||
}
|
||||
logs.CloseAndLogErr(outFile)
|
||||
outFile.Close()
|
||||
default:
|
||||
return fmt.Errorf("tar: unknown type: %v in %s",
|
||||
return fmt.Errorf("Tar: uknown type: %v in %s",
|
||||
header.Typeflag,
|
||||
header.Name)
|
||||
}
|
||||
|
||||
@@ -1,61 +1,51 @@
|
||||
package archive
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/portainer/portainer/api/filesystem"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/docker/docker/pkg/ioutils"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func listFiles(dir string) []string {
|
||||
items := make([]string, 0)
|
||||
|
||||
if err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
|
||||
filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
|
||||
if path == dir {
|
||||
return nil
|
||||
}
|
||||
|
||||
items = append(items, path)
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
log.Warn().Err(err).Msg("failed to list files in directory")
|
||||
}
|
||||
})
|
||||
|
||||
return items
|
||||
}
|
||||
|
||||
func Test_shouldCreateArchive(t *testing.T) {
|
||||
tmpdir := t.TempDir()
|
||||
func Test_shouldCreateArhive(t *testing.T) {
|
||||
tmpdir, _ := ioutils.TempDir("", "backup")
|
||||
defer os.RemoveAll(tmpdir)
|
||||
|
||||
content := []byte("content")
|
||||
|
||||
err := os.WriteFile(path.Join(tmpdir, "outer"), content, 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.MkdirAll(path.Join(tmpdir, "dir"), 0700)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(path.Join(tmpdir, "dir", ".dotfile"), content, 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(path.Join(tmpdir, "dir", "inner"), content, 0600)
|
||||
require.NoError(t, err)
|
||||
ioutil.WriteFile(path.Join(tmpdir, "outer"), content, 0600)
|
||||
os.MkdirAll(path.Join(tmpdir, "dir"), 0700)
|
||||
ioutil.WriteFile(path.Join(tmpdir, "dir", ".dotfile"), content, 0600)
|
||||
ioutil.WriteFile(path.Join(tmpdir, "dir", "inner"), content, 0600)
|
||||
|
||||
gzPath, err := TarGzDir(tmpdir)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, filepath.Join(tmpdir, filepath.Base(tmpdir)+".tar.gz"), gzPath)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, filepath.Join(tmpdir, fmt.Sprintf("%s.tar.gz", filepath.Base(tmpdir))), gzPath)
|
||||
|
||||
extractionDir, _ := ioutils.TempDir("", "extract")
|
||||
defer os.RemoveAll(extractionDir)
|
||||
|
||||
extractionDir := t.TempDir()
|
||||
cmd := exec.Command("tar", "-xzf", gzPath, "-C", extractionDir)
|
||||
if err := cmd.Run(); err != nil {
|
||||
err = cmd.Run()
|
||||
if err != nil {
|
||||
t.Fatal("Failed to extract archive: ", err)
|
||||
}
|
||||
extractedFiles := listFiles(extractionDir)
|
||||
@@ -63,8 +53,7 @@ func Test_shouldCreateArchive(t *testing.T) {
|
||||
wasExtracted := func(p string) {
|
||||
fullpath := path.Join(extractionDir, p)
|
||||
assert.Contains(t, extractedFiles, fullpath)
|
||||
copyContent, err := os.ReadFile(fullpath)
|
||||
require.NoError(t, err)
|
||||
copyContent, _ := ioutil.ReadFile(fullpath)
|
||||
assert.Equal(t, content, copyContent)
|
||||
}
|
||||
|
||||
@@ -73,29 +62,26 @@ func Test_shouldCreateArchive(t *testing.T) {
|
||||
wasExtracted("dir/.dotfile")
|
||||
}
|
||||
|
||||
func Test_shouldCreateArchive2(t *testing.T) {
|
||||
tmpdir := t.TempDir()
|
||||
func Test_shouldCreateArhiveXXXXX(t *testing.T) {
|
||||
tmpdir, _ := ioutils.TempDir("", "backup")
|
||||
defer os.RemoveAll(tmpdir)
|
||||
|
||||
content := []byte("content")
|
||||
|
||||
err := os.WriteFile(path.Join(tmpdir, "outer"), content, 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.MkdirAll(path.Join(tmpdir, "dir"), 0700)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(path.Join(tmpdir, "dir", ".dotfile"), content, 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(path.Join(tmpdir, "dir", "inner"), content, 0600)
|
||||
require.NoError(t, err)
|
||||
ioutil.WriteFile(path.Join(tmpdir, "outer"), content, 0600)
|
||||
os.MkdirAll(path.Join(tmpdir, "dir"), 0700)
|
||||
ioutil.WriteFile(path.Join(tmpdir, "dir", ".dotfile"), content, 0600)
|
||||
ioutil.WriteFile(path.Join(tmpdir, "dir", "inner"), content, 0600)
|
||||
|
||||
gzPath, err := TarGzDir(tmpdir)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, filepath.Join(tmpdir, filepath.Base(tmpdir)+".tar.gz"), gzPath)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, filepath.Join(tmpdir, fmt.Sprintf("%s.tar.gz", filepath.Base(tmpdir))), gzPath)
|
||||
|
||||
extractionDir, _ := ioutils.TempDir("", "extract")
|
||||
defer os.RemoveAll(extractionDir)
|
||||
|
||||
extractionDir := t.TempDir()
|
||||
r, _ := os.Open(gzPath)
|
||||
if err := ExtractTarGz(r, extractionDir); err != nil {
|
||||
ExtractTarGz(r, extractionDir)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to extract archive: ", err)
|
||||
}
|
||||
extractedFiles := listFiles(extractionDir)
|
||||
@@ -103,7 +89,7 @@ func Test_shouldCreateArchive2(t *testing.T) {
|
||||
wasExtracted := func(p string) {
|
||||
fullpath := path.Join(extractionDir, p)
|
||||
assert.Contains(t, extractedFiles, fullpath)
|
||||
copyContent, _ := os.ReadFile(fullpath)
|
||||
copyContent, _ := ioutil.ReadFile(fullpath)
|
||||
assert.Equal(t, content, copyContent)
|
||||
}
|
||||
|
||||
@@ -111,56 +97,3 @@ func Test_shouldCreateArchive2(t *testing.T) {
|
||||
wasExtracted("dir/inner")
|
||||
wasExtracted("dir/.dotfile")
|
||||
}
|
||||
|
||||
func TestExtractTarGzPathTraversal(t *testing.T) {
|
||||
testDir := t.TempDir()
|
||||
|
||||
// Create an evil file with a path traversal attempt
|
||||
tarPath := filesystem.JoinPaths(testDir, "evil.tar.gz")
|
||||
|
||||
evilFile, err := os.Create(tarPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
gzWriter := gzip.NewWriter(evilFile)
|
||||
tarWriter := tar.NewWriter(gzWriter)
|
||||
|
||||
content := []byte("evil content")
|
||||
|
||||
header := &tar.Header{
|
||||
Name: "../evil.txt",
|
||||
Mode: 0600,
|
||||
Size: int64(len(content)),
|
||||
Typeflag: tar.TypeReg,
|
||||
}
|
||||
|
||||
err = tarWriter.WriteHeader(header)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = tarWriter.Write(content)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = tarWriter.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = gzWriter.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = evilFile.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attempt to extract the evil file
|
||||
extractionDir := filesystem.JoinPaths(testDir, "extraction")
|
||||
err = os.Mkdir(extractionDir, 0700)
|
||||
require.NoError(t, err)
|
||||
|
||||
tarFile, err := os.Open(tarPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that the file didn't escape
|
||||
err = ExtractTarGz(tarFile, extractionDir)
|
||||
require.NoError(t, err)
|
||||
require.NoFileExists(t, filesystem.JoinPaths(testDir, "evil.txt"))
|
||||
|
||||
err = tarFile.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -2,17 +2,60 @@ package archive
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"github.com/pkg/errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/portainer/portainer/api/logs"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// UnzipArchive will unzip an archive from bytes into the dest destination folder on disk
|
||||
func UnzipArchive(archiveData []byte, dest string) error {
|
||||
zipReader, err := zip.NewReader(bytes.NewReader(archiveData), int64(len(archiveData)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, zipFile := range zipReader.File {
|
||||
err := extractFileFromArchive(zipFile, dest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func extractFileFromArchive(file *zip.File, dest string) error {
|
||||
f, err := file.Open()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
data, err := ioutil.ReadAll(f)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fpath := filepath.Join(dest, file.Name)
|
||||
|
||||
outFile, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, file.Mode())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = io.Copy(outFile, bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return outFile.Close()
|
||||
}
|
||||
|
||||
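As a usage sketch, the helper below (hypothetical, not part of this diff) reads a zip file from disk and hands its bytes to UnzipArchive:

package archive

import "os"

// unzipFromDisk is an illustrative caller of UnzipArchive.
func unzipFromDisk(archivePath, dest string) error {
	data, err := os.ReadFile(archivePath)
	if err != nil {
		return err
	}

	// every entry in the archive is written beneath dest
	return UnzipArchive(data, dest)
}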
// UnzipFile will decompress a zip archive, moving all files and folders
|
||||
// within the zip file (parameter 1) to an output directory (parameter 2).
|
||||
func UnzipFile(src string, dest string) error {
|
||||
@@ -20,7 +63,7 @@ func UnzipFile(src string, dest string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer logs.CloseAndLogErr(r)
|
||||
defer r.Close()
|
||||
|
||||
for _, f := range r.File {
|
||||
p := filepath.Join(dest, f.Name)
|
||||
@@ -32,14 +75,12 @@ func UnzipFile(src string, dest string) error {
|
||||
|
||||
if f.FileInfo().IsDir() {
|
||||
// Make Folder
|
||||
if err := os.MkdirAll(p, os.ModePerm); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
os.MkdirAll(p, os.ModePerm)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := unzipFile(f, p); err != nil {
|
||||
err = unzipFile(f, p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -52,20 +93,20 @@ func unzipFile(f *zip.File, p string) error {
|
||||
if err := os.MkdirAll(filepath.Dir(p), os.ModePerm); err != nil {
|
||||
return errors.Wrapf(err, "unzipFile: can't make a path %s", p)
|
||||
}
|
||||
|
||||
outFile, err := os.OpenFile(p, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "unzipFile: can't create file %s", p)
|
||||
}
|
||||
defer logs.CloseAndLogErr(outFile)
|
||||
|
||||
defer outFile.Close()
|
||||
rc, err := f.Open()
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "unzipFile: can't open zip file %s in the archive", f.Name)
|
||||
}
|
||||
defer logs.CloseAndLogErr(rc)
|
||||
defer rc.Close()
|
||||
|
||||
if _, err = io.Copy(outFile, rc); err != nil {
|
||||
_, err = io.Copy(outFile, rc)
|
||||
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "unzipFile: can't copy an archived file content")
|
||||
}
|
||||
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
package archive
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUnzipFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
dir, err := ioutil.TempDir("", "unzip-test-")
|
||||
assert.NoError(t, err)
|
||||
defer os.RemoveAll(dir)
|
||||
/*
|
||||
Archive structure.
|
||||
├── 0
|
||||
@@ -19,9 +21,9 @@ func TestUnzipFile(t *testing.T) {
|
||||
└── 0.txt
|
||||
*/
|
||||
|
||||
err := UnzipFile("./testdata/sample_archive.zip", dir)
|
||||
err = UnzipFile("./testdata/sample_archive.zip", dir)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
archiveDir := dir + "/sample_archive"
|
||||
assert.FileExists(t, filepath.Join(archiveDir, "0.txt"))
|
||||
assert.FileExists(t, filepath.Join(archiveDir, "0", "1.txt"))
|
||||
|
||||
@@ -3,7 +3,7 @@ package ecr
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -15,7 +15,7 @@ func (s *Service) GetEncodedAuthorizationToken() (token *string, expiry *time.Ti
|
||||
}
|
||||
|
||||
if len(getAuthorizationTokenOutput.AuthorizationData) == 0 {
|
||||
err = errors.New("AuthorizationData is empty")
|
||||
err = fmt.Errorf("AuthorizationData is empty")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ func (s *Service) ParseAuthorizationToken(token string) (username string, passwo
|
||||
|
||||
splitToken := strings.Split(token, ":")
|
||||
if len(splitToken) < 2 {
|
||||
err = errors.New("invalid ECR authorization token")
|
||||
err = fmt.Errorf("invalid ECR authorization token")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -6,15 +6,6 @@ import (
|
||||
"github.com/aws/aws-sdk-go-v2/service/ecr"
|
||||
)
|
||||
|
||||
// Registry represents ECR registry endpoint information.
|
||||
// This struct is used to parse and validate ECR endpoint URLs.
|
||||
type Registry struct {
|
||||
ID string // AWS account ID (empty for accountless endpoints like "ecr-fips.us-west-1.amazonaws.com")
|
||||
FIPS bool // Whether this is a FIPS endpoint (contains "-fips" in the URL)
|
||||
Region string // AWS region (e.g., "us-east-1", "us-gov-west-1")
|
||||
Public bool // Whether this is ecr-public.aws.com
|
||||
}
|
||||
|
||||
type (
|
||||
Service struct {
|
||||
accessKey string
|
||||
|
||||
@@ -1,70 +0,0 @@
|
||||
package ecr
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ecrEndpointPattern matches all valid ECR endpoints including account-prefixed and accountless formats.
|
||||
// Based on AWS ECR credential helper regex but extended to support accountless endpoints.
|
||||
//
|
||||
// Supported formats:
|
||||
// - Account-prefixed: 123456789012.dkr.ecr-fips.us-east-1.amazonaws.com
|
||||
// - Account-prefixed (hyphen): 123456789012.dkr-ecr-fips.us-west-1.on.aws
|
||||
// - Accountless service: ecr-fips.us-west-1.amazonaws.com
|
||||
// - Accountless API: ecr-fips.us-east-1.api.aws
|
||||
// - Non-FIPS variants: All formats above without "-fips"
|
||||
//
|
||||
// Regex groups:
|
||||
// - Group 1: Full account prefix (optional) - e.g., "123456789012.dkr." or "123456789012.dkr-"
|
||||
// - Group 2: Account ID (optional) - e.g., "123456789012"
|
||||
// - Group 3: FIPS flag (optional) - either "-fips" or empty string
|
||||
// - Group 4: Region - e.g., "us-east-1", "us-gov-west-1"
|
||||
// - Group 5: Domain suffix - e.g., "amazonaws.com", "api.aws"
|
||||
var ecrEndpointPattern = regexp.MustCompile(
|
||||
`^((\d{12})\.dkr[\.\-])?ecr(\-fips)?\.([a-zA-Z0-9][a-zA-Z0-9-_]*)\.(amazonaws\.(?:com(?:\.cn)?|eu)|api\.aws|on\.(?:aws|amazonwebservices\.com\.cn)|sc2s\.sgov\.gov|c2s\.ic\.gov|cloud\.adc-e\.uk|csp\.hci\.ic\.gov)$`,
|
||||
)
|
||||
|
||||
// ParseECREndpoint parses an ECR registry URL and extracts registry information.
|
||||
|
||||
// This function replaces the AWS ECR credential helper library's ExtractRegistry function,
|
||||
// which only supports account-prefixed endpoints.
|
||||
//
|
||||
// Reference: https://docs.aws.amazon.com/general/latest/gr/ecr.html
|
||||
func ParseECREndpoint(urlStr string) (*Registry, error) {
|
||||
// Normalize URL by adding https:// prefix if not present
|
||||
if !strings.HasPrefix(urlStr, "https://") && !strings.HasPrefix(urlStr, "http://") {
|
||||
urlStr = "https://" + urlStr
|
||||
}
|
||||
|
||||
u, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid URL: %w", err)
|
||||
}
|
||||
|
||||
hostname := u.Hostname()
|
||||
|
||||
// Special case: ECR Public
|
||||
// ECR Public uses a different domain and doesn't have FIPS variant
|
||||
if hostname == "ecr-public.aws.com" {
|
||||
return &Registry{
|
||||
FIPS: false,
|
||||
Public: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Parse standard ECR endpoints using regex
|
||||
matches := ecrEndpointPattern.FindStringSubmatch(hostname)
|
||||
if len(matches) == 0 {
|
||||
return nil, fmt.Errorf("not a valid ECR endpoint: %s", hostname)
|
||||
}
|
||||
|
||||
return &Registry{
|
||||
ID: matches[2], // Account ID (may be empty for accountless endpoints)
|
||||
FIPS: matches[3] == "-fips", // Check if "-fips" is present
|
||||
Region: matches[4], // AWS region
|
||||
Public: false,
|
||||
}, nil
|
||||
}
|
||||
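A small sketch of calling the parser above; describeRegistry is illustrative only, and the example endpoint follows the formats documented with the regex:

package ecr

import "fmt"

// describeRegistry is an illustrative helper around ParseECREndpoint.
func describeRegistry(url string) {
	reg, err := ParseECREndpoint(url)
	if err != nil {
		fmt.Println("not an ECR endpoint:", err)
		return
	}

	// "123456789012.dkr.ecr-fips.us-east-1.amazonaws.com" yields
	// ID="123456789012", FIPS=true, Region="us-east-1", Public=false
	fmt.Println(reg.ID, reg.FIPS, reg.Region, reg.Public)
}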
@@ -1,253 +0,0 @@
|
||||
package ecr
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseECREndpoint(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
want *Registry
|
||||
wantError bool
|
||||
}{
|
||||
// Standard AWS Commercial - Account-prefixed FIPS
|
||||
{
|
||||
name: "account-prefixed FIPS us-east-1",
|
||||
url: "123456789012.dkr.ecr-fips.us-east-1.amazonaws.com",
|
||||
want: &Registry{
|
||||
ID: "123456789012",
|
||||
FIPS: true,
|
||||
Region: "us-east-1",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "account-prefixed FIPS us-west-2",
|
||||
url: "123456789012.dkr.ecr-fips.us-west-2.amazonaws.com",
|
||||
want: &Registry{
|
||||
ID: "123456789012",
|
||||
FIPS: true,
|
||||
Region: "us-west-2",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
|
||||
// Accountless FIPS service endpoints
|
||||
{
|
||||
name: "accountless FIPS us-west-1",
|
||||
url: "ecr-fips.us-west-1.amazonaws.com",
|
||||
want: &Registry{
|
||||
ID: "",
|
||||
FIPS: true,
|
||||
Region: "us-west-1",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "accountless FIPS us-east-2",
|
||||
url: "ecr-fips.us-east-2.amazonaws.com",
|
||||
want: &Registry{
|
||||
ID: "",
|
||||
FIPS: true,
|
||||
Region: "us-east-2",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
|
||||
// Accountless FIPS API endpoints
|
||||
{
|
||||
name: "accountless FIPS API us-west-1",
|
||||
url: "ecr-fips.us-west-1.api.aws",
|
||||
want: &Registry{
|
||||
ID: "",
|
||||
FIPS: true,
|
||||
Region: "us-west-1",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "accountless FIPS API us-east-1",
|
||||
url: "ecr-fips.us-east-1.api.aws",
|
||||
want: &Registry{
|
||||
ID: "",
|
||||
FIPS: true,
|
||||
Region: "us-east-1",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
|
||||
// on.aws domain with hyphen separator
|
||||
{
|
||||
name: "account-prefixed FIPS hyphen us-west-1",
|
||||
url: "123456789012.dkr-ecr-fips.us-west-1.on.aws",
|
||||
want: &Registry{
|
||||
ID: "123456789012",
|
||||
FIPS: true,
|
||||
Region: "us-west-1",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "account-prefixed FIPS hyphen us-east-2",
|
||||
url: "123456789012.dkr-ecr-fips.us-east-2.on.aws",
|
||||
want: &Registry{
|
||||
ID: "123456789012",
|
||||
FIPS: true,
|
||||
Region: "us-east-2",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
|
||||
// AWS GovCloud
|
||||
{
|
||||
name: "account-prefixed FIPS us-gov-east-1",
|
||||
url: "123456789012.dkr.ecr-fips.us-gov-east-1.amazonaws.com",
|
||||
want: &Registry{
|
||||
ID: "123456789012",
|
||||
FIPS: true,
|
||||
Region: "us-gov-east-1",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "account-prefixed FIPS us-gov-west-1",
|
||||
url: "123456789012.dkr.ecr-fips.us-gov-west-1.amazonaws.com",
|
||||
want: &Registry{
|
||||
ID: "123456789012",
|
||||
FIPS: true,
|
||||
Region: "us-gov-west-1",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "accountless FIPS us-gov-west-1",
|
||||
url: "ecr-fips.us-gov-west-1.amazonaws.com",
|
||||
want: &Registry{
|
||||
ID: "",
|
||||
FIPS: true,
|
||||
Region: "us-gov-west-1",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "accountless FIPS API us-gov-east-1",
|
||||
url: "ecr-fips.us-gov-east-1.api.aws",
|
||||
want: &Registry{
|
||||
ID: "",
|
||||
FIPS: true,
|
||||
Region: "us-gov-east-1",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
|
||||
// ECR Public
|
||||
{
|
||||
name: "ecr-public",
|
||||
url: "ecr-public.aws.com",
|
||||
want: &Registry{
|
||||
ID: "",
|
||||
FIPS: false,
|
||||
Region: "",
|
||||
Public: true,
|
||||
},
|
||||
},
|
||||
|
||||
// Non-FIPS endpoints (valid ECR but FIPS=false)
|
||||
{
|
||||
name: "account-prefixed non-FIPS us-east-1",
|
||||
url: "123456789012.dkr.ecr.us-east-1.amazonaws.com",
|
||||
want: &Registry{
|
||||
ID: "123456789012",
|
||||
FIPS: false,
|
||||
Region: "us-east-1",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "accountless non-FIPS us-west-1",
|
||||
url: "ecr.us-west-1.amazonaws.com",
|
||||
want: &Registry{
|
||||
ID: "",
|
||||
FIPS: false,
|
||||
Region: "us-west-1",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "accountless non-FIPS API us-east-2",
|
||||
url: "ecr.us-east-2.api.aws",
|
||||
want: &Registry{
|
||||
ID: "",
|
||||
FIPS: false,
|
||||
Region: "us-east-2",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
|
||||
// URLs with https:// prefix
|
||||
{
|
||||
name: "with https prefix",
|
||||
url: "https://ecr-fips.us-west-1.amazonaws.com",
|
||||
want: &Registry{
|
||||
ID: "",
|
||||
FIPS: true,
|
||||
Region: "us-west-1",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
|
||||
// Invalid endpoints
|
||||
{
|
||||
name: "not an ECR URL",
|
||||
url: "not-an-ecr-url.com",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid account ID length",
|
||||
url: "123.dkr.ecr-fips.us-east-1.amazonaws.com",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
url: "",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "docker hub",
|
||||
url: "docker.io",
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ParseECREndpoint(tt.url)
|
||||
|
||||
if tt.wantError {
|
||||
if err == nil {
|
||||
t.Errorf("ParseECREndpoint() expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("ParseECREndpoint() unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if got.ID != tt.want.ID {
|
||||
t.Errorf("ParseECREndpoint() ID = %v, want %v", got.ID, tt.want.ID)
|
||||
}
|
||||
if got.FIPS != tt.want.FIPS {
|
||||
t.Errorf("ParseECREndpoint() FIPS = %v, want %v", got.FIPS, tt.want.FIPS)
|
||||
}
|
||||
if got.Region != tt.want.Region {
|
||||
t.Errorf("ParseECREndpoint() Region = %v, want %v", got.Region, tt.want.Region)
|
||||
}
|
||||
if got.Public != tt.want.Public {
|
||||
t.Errorf("ParseECREndpoint() Public = %v, want %v", got.Public, tt.want.Public)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -7,22 +7,19 @@ import (
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/portainer/portainer/api/archive"
|
||||
"github.com/portainer/portainer/api/crypto"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/filesystem"
|
||||
"github.com/portainer/portainer/api/http/offlinegate"
|
||||
"github.com/portainer/portainer/api/logs"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const rwxr__r__ os.FileMode = 0o744
|
||||
const rwxr__r__ os.FileMode = 0744
|
||||
|
||||
var filesToBackup = []string{
|
||||
"certs",
|
||||
"chisel",
|
||||
"compose",
|
||||
"config.json",
|
||||
"custom_templates",
|
||||
@@ -36,9 +33,35 @@ var filesToBackup = []string{
|
||||
|
||||
// Creates a tar.gz system archive and encrypts it if password is not empty. Returns a path to the archive file.
|
||||
func CreateBackupArchive(password string, gate *offlinegate.OfflineGate, datastore dataservices.DataStore, filestorePath string) (string, error) {
|
||||
backupDirPath, err := backupDatabaseAndFilesystem(gate, datastore, filestorePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
unlock := gate.Lock()
|
||||
defer unlock()
|
||||
|
||||
backupDirPath := filepath.Join(filestorePath, "backup", time.Now().Format("2006-01-02_15-04-05"))
|
||||
if err := os.MkdirAll(backupDirPath, rwxr__r__); err != nil {
|
||||
return "", errors.Wrap(err, "Failed to create backup dir")
|
||||
}
|
||||
|
||||
{
|
||||
// new export
|
||||
exportFilename := path.Join(backupDirPath, fmt.Sprintf("export-%d.json", time.Now().Unix()))
|
||||
|
||||
err := datastore.Export(exportFilename)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Debugf("failed to export to %s", exportFilename)
|
||||
} else {
|
||||
logrus.Debugf("exported to %s", exportFilename)
|
||||
}
|
||||
}
|
||||
|
||||
if err := backupDb(backupDirPath, datastore); err != nil {
|
||||
return "", errors.Wrap(err, "Failed to backup database")
|
||||
}
|
||||
|
||||
for _, filename := range filesToBackup {
|
||||
err := filesystem.CopyPath(filepath.Join(filestorePath, filename), backupDirPath)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "Failed to create backup file")
|
||||
}
|
||||
}
|
||||
|
||||
archivePath, err := archive.TarGzDir(backupDirPath)
|
||||
@@ -56,41 +79,15 @@ func CreateBackupArchive(password string, gate *offlinegate.OfflineGate, datasto
|
||||
return archivePath, nil
|
||||
}
|
||||
|
||||
func backupDatabaseAndFilesystem(gate *offlinegate.OfflineGate, datastore dataservices.DataStore, filestorePath string) (string, error) {
|
||||
unlock := gate.Lock()
|
||||
defer unlock()
|
||||
|
||||
backupDirPath := filepath.Join(filestorePath, "backup", time.Now().Format("2006-01-02_15-04-05"))
|
||||
if err := os.MkdirAll(backupDirPath, rwxr__r__); err != nil {
|
||||
return "", errors.Wrap(err, "Failed to create backup dir")
|
||||
}
|
||||
|
||||
// new export
|
||||
exportFilename := path.Join(backupDirPath, fmt.Sprintf("export-%d.json", time.Now().Unix()))
|
||||
|
||||
if err := datastore.Export(exportFilename); err != nil {
|
||||
log.Error().Err(err).Str("filename", exportFilename).Msg("failed to export")
|
||||
} else {
|
||||
log.Debug().Str("filename", exportFilename).Msg("file exported")
|
||||
}
|
||||
|
||||
if err := backupDb(backupDirPath, datastore); err != nil {
|
||||
return "", errors.Wrap(err, "Failed to backup database")
|
||||
}
|
||||
|
||||
for _, filename := range filesToBackup {
|
||||
if err := filesystem.CopyPath(filepath.Join(filestorePath, filename), backupDirPath); err != nil {
|
||||
return "", errors.Wrap(err, "Failed to create backup file")
|
||||
}
|
||||
}
|
||||
|
||||
return backupDirPath, nil
|
||||
}
|
||||
|
||||
func backupDb(backupDirPath string, datastore dataservices.DataStore) error {
|
||||
dbFileName := datastore.Connection().GetDatabaseFileName()
|
||||
_, err := datastore.Backup(filepath.Join(backupDirPath, dbFileName))
|
||||
return err
|
||||
backupWriter, err := os.Create(filepath.Join(backupDirPath, "portainer.db"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = datastore.BackupTo(backupWriter); err != nil {
|
||||
return err
|
||||
}
|
||||
return backupWriter.Close()
|
||||
}
|
||||
|
||||
func encrypt(path string, passphrase string) (string, error) {
|
||||
@@ -98,13 +95,15 @@ func encrypt(path string, passphrase string) (string, error) {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer logs.CloseAndLogErr(in)
|
||||
defer in.Close()
|
||||
|
||||
outFileName := path + ".encrypted"
|
||||
outFileName := fmt.Sprintf("%s.encrypted", path)
|
||||
out, err := os.Create(outFileName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return outFileName, crypto.AesEncrypt(in, out, []byte(passphrase))
|
||||
err = crypto.AesEncrypt(in, out, []byte(passphrase))
|
||||
|
||||
return outFileName, err
|
||||
}
|
||||
|
||||
@@ -3,10 +3,8 @@ package backup
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
@@ -16,8 +14,6 @@ import (
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/filesystem"
|
||||
"github.com/portainer/portainer/api/http/offlinegate"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
var filesToRestore = append(filesToBackup, "portainer.db")
|
||||
@@ -28,35 +24,26 @@ func RestoreArchive(archive io.Reader, password string, filestorePath string, ga
|
||||
if password != "" {
|
||||
archive, err = decrypt(archive, password)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to decrypt the archive. Please ensure the password is correct and try again")
|
||||
return errors.Wrap(err, "failed to decrypt the archive")
|
||||
}
|
||||
}
|
||||
|
||||
restorePath := filepath.Join(filestorePath, "restore", time.Now().Format("20060102150405"))
|
||||
defer func() {
|
||||
if err := os.RemoveAll(filepath.Dir(restorePath)); err != nil {
|
||||
log.Warn().Err(err).Msg("failed to clean up restore files")
|
||||
}
|
||||
}()
|
||||
defer os.RemoveAll(filepath.Dir(restorePath))
|
||||
|
||||
if err := extractArchive(archive, restorePath); err != nil {
|
||||
err = extractArchive(archive, restorePath)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "cannot extract files from the archive. Please ensure the password is correct and try again")
|
||||
}
|
||||
|
||||
unlock := gate.Lock()
|
||||
defer unlock()
|
||||
|
||||
if err := datastore.Close(); err != nil {
|
||||
if err = datastore.Close(); err != nil {
|
||||
return errors.Wrap(err, "Failed to stop db")
|
||||
}
|
||||
|
||||
// At some point, backups were created containing a subdirectory, now we need to handle both
|
||||
restorePath, err = getRestoreSourcePath(restorePath)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to restore from backup. Portainer database missing from backup file")
|
||||
}
|
||||
|
||||
if err := restoreFiles(restorePath, filestorePath); err != nil {
|
||||
if err = restoreFiles(restorePath, filestorePath); err != nil {
|
||||
return errors.Wrap(err, "failed to restore the system state")
|
||||
}
|
||||
|
||||
@@ -72,29 +59,10 @@ func extractArchive(r io.Reader, destinationDirPath string) error {
|
||||
return archive.ExtractTarGz(r, destinationDirPath)
|
||||
}
|
||||
|
||||
func getRestoreSourcePath(dir string) (string, error) {
|
||||
// find portainer.db or portainer.edb file. Return the parent directory
|
||||
var portainerdbRegex = regexp.MustCompile(`^portainer.e?db$`)
|
||||
|
||||
backupDirPath := dir
|
||||
err := filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if portainerdbRegex.MatchString(d.Name()) {
|
||||
backupDirPath = filepath.Dir(path)
|
||||
return filepath.SkipDir
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return backupDirPath, err
|
||||
}
|
||||
|
||||
func restoreFiles(srcDir string, destinationDir string) error {
|
||||
for _, filename := range filesToRestore {
|
||||
if err := filesystem.CopyPath(filepath.Join(srcDir, filename), destinationDir); err != nil {
|
||||
err := filesystem.CopyPath(filepath.Join(srcDir, filename), destinationDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -102,18 +70,14 @@ func restoreFiles(srcDir string, destinationDir string) error {
|
||||
// TODO: This is very boltdb module specific once again due to the filename. Move to bolt module? Refactor for another day
|
||||
|
||||
// Prevent the possibility of having both databases. Remove any default new instance
|
||||
if err := os.Remove(filepath.Join(destinationDir, boltdb.DatabaseFileName)); err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.Remove(filepath.Join(destinationDir, boltdb.EncryptedDatabaseFileName)); err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
os.Remove(filepath.Join(destinationDir, boltdb.DatabaseFileName))
|
||||
os.Remove(filepath.Join(destinationDir, boltdb.EncryptedDatabaseFileName))
|
||||
|
||||
// Now copy the database. It'll be either portainer.db or portainer.edb
|
||||
|
||||
// Note: CopyPath does not return an error if the source file doesn't exist
|
||||
if err := filesystem.CopyPath(filepath.Join(srcDir, boltdb.EncryptedDatabaseFileName), destinationDir); err != nil {
|
||||
err := filesystem.CopyPath(filepath.Join(srcDir, boltdb.EncryptedDatabaseFileName), destinationDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
|
||||
chshare "github.com/jpillora/chisel/share"
|
||||
)
|
||||
|
||||
var one = new(big.Int).SetInt64(1)
|
||||
|
||||
// GenerateGo119CompatibleKey This function is basically copied from chshare.GenerateKey.
|
||||
func GenerateGo119CompatibleKey(seed string) ([]byte, error) {
|
||||
r := chshare.NewDetermRand([]byte(seed))
|
||||
priv, err := ecdsaGenerateKey(elliptic.P256(), r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b, err := x509.MarshalECPrivateKey(priv)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to marshal ECDSA private key: %w", err)
|
||||
}
|
||||
return pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: b}), nil
|
||||
}
|
||||
|
||||
// This function is copied from Go1.19
|
||||
func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error) {
|
||||
params := c.Params()
|
||||
// Note that for P-521 this will actually be 63 bits more than the order, as
|
||||
// division rounds down, but the extra bit is inconsequential.
|
||||
b := make([]byte, params.N.BitLen()/8+8)
|
||||
_, err = io.ReadFull(rand, b)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
k = new(big.Int).SetBytes(b)
|
||||
n := new(big.Int).Sub(params.N, one)
|
||||
k.Mod(k, n)
|
||||
k.Add(k, one)
|
||||
return
|
||||
}
|
||||
|
||||
// This function is copied from Go1.19
|
||||
func ecdsaGenerateKey(c elliptic.Curve, rand io.Reader) (*ecdsa.PrivateKey, error) {
|
||||
k, err := randFieldElement(c, rand)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
priv := new(ecdsa.PrivateKey)
|
||||
priv.Curve = c
|
||||
priv.D = k
|
||||
priv.X, priv.Y = c.ScalarBaseMult(k.Bytes())
|
||||
return priv, nil
|
||||
}
|
||||
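The only property the removed helper really guarantees is determinism: the same seed must always produce the same PEM, and therefore a stable chisel server fingerprint across restarts. A minimal check of that property, assuming the function is still in scope in the crypto package, could look like this:

package crypto

import (
	"bytes"
	"testing"
)

// TestSeededKeyIsDeterministic is an illustrative sketch, not part of the
// repository: it only asserts that two calls with the same seed agree.
func TestSeededKeyIsDeterministic(t *testing.T) {
	first, err := GenerateGo119CompatibleKey("94qh17MCIk8BOkiI")
	if err != nil {
		t.Fatal(err)
	}

	second, err := GenerateGo119CompatibleKey("94qh17MCIk8BOkiI")
	if err != nil {
		t.Fatal(err)
	}

	if !bytes.Equal(first, second) {
		t.Fatal("the same seed must produce the same PEM-encoded key")
	}
}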
@@ -1,37 +0,0 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGenerateGo119CompatibleKey(t *testing.T) {
|
||||
type args struct {
|
||||
seed string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Generate Go 1.19 compatible private key with a given seed",
|
||||
args: args{seed: "94qh17MCIk8BOkiI"},
|
||||
want: []byte("-----BEGIN EC PRIVATE KEY-----\nMHcCAQEEIHeohwk0Gy3RHVVViaHz7pz/HOiqA7fkv1FTM3mGgfT3oAoGCCqGSM49\nAwEHoUQDQgAEN7riX06xDsLNPuUmOvYFluNEakcFwZZRVvOcIYk/9VYnanDzW0Km\n8/BUUiKyJDuuGdS4fj9SlQ4iL8yBK01uKg==\n-----END EC PRIVATE KEY-----\n"),
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := GenerateGo119CompatibleKey(tt.args.seed)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("GenerateGo119CompatibleKey() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
	t.Errorf("GenerateGo119CompatibleKey()\ngot: %v\nwant: %v", got, tt.want)
}
|
||||
})
|
||||
}
|
||||
}
|
||||
47
api/chisel/schedules.go
Normal file
47
api/chisel/schedules.go
Normal file
@@ -0,0 +1,47 @@
package chisel

import (
	portainer "github.com/portainer/portainer/api"
)

// AddEdgeJob register an EdgeJob inside the tunnel details associated to an environment(endpoint).
func (service *Service) AddEdgeJob(endpointID portainer.EndpointID, edgeJob *portainer.EdgeJob) {
	service.mu.Lock()
	tunnel := service.getTunnelDetails(endpointID)

	existingJobIndex := -1
	for idx, existingJob := range tunnel.Jobs {
		if existingJob.ID == edgeJob.ID {
			existingJobIndex = idx
			break
		}
	}

	if existingJobIndex == -1 {
		tunnel.Jobs = append(tunnel.Jobs, *edgeJob)
	} else {
		tunnel.Jobs[existingJobIndex] = *edgeJob
	}

	service.mu.Unlock()
}

// RemoveEdgeJob will remove the specified Edge job from each tunnel it was registered with.
func (service *Service) RemoveEdgeJob(edgeJobID portainer.EdgeJobID) {
	service.mu.Lock()

	for _, tunnel := range service.tunnelDetailsMap {
		// Filter in-place
		n := 0
		for _, edgeJob := range tunnel.Jobs {
			if edgeJob.ID != edgeJobID {
				tunnel.Jobs[n] = edgeJob
				n++
			}
		}

		tunnel.Jobs = tunnel.Jobs[:n]
	}

	service.mu.Unlock()
}
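RemoveEdgeJob uses the standard in-place filter idiom: keep a write index and reslice at the end, so no second slice is allocated. A self-contained sketch of the same idiom on plain integers:

package main

import "fmt"

// filterInPlace drops every occurrence of drop from jobs without allocating
// a new slice, mirroring the loop in RemoveEdgeJob above.
func filterInPlace(jobs []int, drop int) []int {
	n := 0
	for _, j := range jobs {
		if j != drop {
			jobs[n] = j
			n++
		}
	}
	return jobs[:n]
}

func main() {
	fmt.Println(filterInPlace([]int{1, 2, 3, 2}, 2)) // [1 3]
}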
@@ -3,141 +3,91 @@ package chisel
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/dchest/uniuri"
|
||||
chserver "github.com/jpillora/chisel/server"
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/http/proxy"
|
||||
|
||||
chserver "github.com/jpillora/chisel/server"
|
||||
"github.com/jpillora/chisel/share/ccrypto"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
const (
|
||||
tunnelCleanupInterval = 10 * time.Second
|
||||
requiredTimeout = 15 * time.Second
|
||||
activeTimeout = 4*time.Minute + 30*time.Second
|
||||
pingTimeout = 3 * time.Second
|
||||
)
|
||||
|
||||
// Service represents a service to manage the state of multiple reverse tunnels.
|
||||
// It is used to start a reverse tunnel server and to manage the connection status of each tunnel
|
||||
// connected to the tunnel server.
|
||||
type Service struct {
|
||||
serverFingerprint string
|
||||
serverPort string
|
||||
activeTunnels map[portainer.EndpointID]*portainer.TunnelDetails
|
||||
edgeJobs map[portainer.EndpointID][]portainer.EdgeJob
|
||||
dataStore dataservices.DataStore
|
||||
snapshotService portainer.SnapshotService
|
||||
chiselServer *chserver.Server
|
||||
shutdownCtx context.Context
|
||||
ProxyManager *proxy.Manager
|
||||
mu sync.RWMutex
|
||||
fileService portainer.FileService
|
||||
defaultCheckinInterval int
|
||||
serverFingerprint string
|
||||
serverPort string
|
||||
tunnelDetailsMap map[portainer.EndpointID]*portainer.TunnelDetails
|
||||
dataStore dataservices.DataStore
|
||||
snapshotService portainer.SnapshotService
|
||||
chiselServer *chserver.Server
|
||||
shutdownCtx context.Context
|
||||
ProxyManager *proxy.Manager
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewService returns a pointer to a new instance of Service
|
||||
func NewService(dataStore dataservices.DataStore, shutdownCtx context.Context, fileService portainer.FileService) *Service {
|
||||
defaultCheckinInterval := portainer.DefaultEdgeAgentCheckinIntervalInSeconds
|
||||
|
||||
settings, err := dataStore.Settings().Settings()
|
||||
if err == nil {
|
||||
defaultCheckinInterval = settings.EdgeAgentCheckinInterval
|
||||
} else {
|
||||
log.Error().Err(err).Msg("unable to retrieve the settings from the database")
|
||||
}
|
||||
|
||||
func NewService(dataStore dataservices.DataStore, shutdownCtx context.Context) *Service {
|
||||
return &Service{
|
||||
activeTunnels: make(map[portainer.EndpointID]*portainer.TunnelDetails),
|
||||
edgeJobs: make(map[portainer.EndpointID][]portainer.EdgeJob),
|
||||
dataStore: dataStore,
|
||||
shutdownCtx: shutdownCtx,
|
||||
fileService: fileService,
|
||||
defaultCheckinInterval: defaultCheckinInterval,
|
||||
tunnelDetailsMap: make(map[portainer.EndpointID]*portainer.TunnelDetails),
|
||||
dataStore: dataStore,
|
||||
shutdownCtx: shutdownCtx,
|
||||
}
|
||||
}
|
||||
|
||||
// pingAgent ping the given agent so that the agent can keep the tunnel alive
|
||||
func (service *Service) pingAgent(endpointID portainer.EndpointID) error {
|
||||
endpoint, err := service.dataStore.Endpoint().Endpoint(endpointID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tunnelAddr, err := service.TunnelAddr(endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
requestURL := fmt.Sprintf("http://%s/ping", tunnelAddr)
|
||||
tunnel := service.GetTunnelDetails(endpointID)
|
||||
requestURL := fmt.Sprintf("http://127.0.0.1:%d/ping", tunnel.Port)
|
||||
req, err := http.NewRequest(http.MethodHead, requestURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: pingTimeout,
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, _ = io.Copy(io.Discard, resp.Body)
|
||||
return resp.Body.Close()
|
||||
_, err = httpClient.Do(req)
|
||||
return err
|
||||
}
|
||||
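The updated pingAgent issues a HEAD request with a short timeout and, in the newer variant, drains and closes the response body so the underlying connection can be reused. A standalone sketch of that pattern, with an illustrative address and path:

package main

import (
	"fmt"
	"io"
	"net/http"
	"time"
)

// pingOnce sends a HEAD request with a short timeout, then drains and closes
// the body so the HTTP client can reuse the connection.
func pingOnce(addr string) error {
	req, err := http.NewRequest(http.MethodHead, fmt.Sprintf("http://%s/ping", addr), nil)
	if err != nil {
		return err
	}

	client := &http.Client{Timeout: 3 * time.Second}

	resp, err := client.Do(req)
	if err != nil {
		return err
	}

	_, _ = io.Copy(io.Discard, resp.Body)
	return resp.Body.Close()
}

func main() {
	fmt.Println(pingOnce("127.0.0.1:8000"))
}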
|
||||
// KeepTunnelAlive keeps the tunnel of the given environment for maxAlive duration, or until ctx is done
|
||||
func (service *Service) KeepTunnelAlive(endpointID portainer.EndpointID, ctx context.Context, maxAlive time.Duration) {
|
||||
go service.keepTunnelAlive(endpointID, ctx, maxAlive)
|
||||
}
|
||||
go func() {
|
||||
log.Printf("[DEBUG] [chisel,KeepTunnelAlive] [endpoint_id: %d] [message: start for %.0f minutes]\n", endpointID, maxAlive.Minutes())
|
||||
maxAliveTicker := time.NewTicker(maxAlive)
|
||||
defer maxAliveTicker.Stop()
|
||||
pingTicker := time.NewTicker(tunnelCleanupInterval)
|
||||
defer pingTicker.Stop()
|
||||
|
||||
func (service *Service) keepTunnelAlive(endpointID portainer.EndpointID, ctx context.Context, maxAlive time.Duration) {
|
||||
log.Debug().
|
||||
Int("endpoint_id", int(endpointID)).
|
||||
Float64("max_alive_minutes", maxAlive.Minutes()).
|
||||
Msg("KeepTunnelAlive: start")
|
||||
|
||||
maxAliveTicker := time.NewTicker(maxAlive)
|
||||
defer maxAliveTicker.Stop()
|
||||
|
||||
pingTicker := time.NewTicker(tunnelCleanupInterval)
|
||||
defer pingTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-pingTicker.C:
|
||||
service.UpdateLastActivity(endpointID)
|
||||
|
||||
if err := service.pingAgent(endpointID); err != nil {
|
||||
log.Debug().
|
||||
Int("endpoint_id", int(endpointID)).
|
||||
Err(err).
|
||||
Msg("KeepTunnelAlive: ping agent")
|
||||
for {
|
||||
select {
|
||||
case <-pingTicker.C:
|
||||
service.SetTunnelStatusToActive(endpointID)
|
||||
err := service.pingAgent(endpointID)
|
||||
if err != nil {
|
||||
log.Printf("[DEBUG] [chisel,KeepTunnelAlive] [endpoint_id: %d] [warning: ping agent err=%s]\n", endpointID, err)
|
||||
}
|
||||
case <-maxAliveTicker.C:
|
||||
log.Printf("[DEBUG] [chisel,KeepTunnelAlive] [endpoint_id: %d] [message: stop as %.0f minutes timeout]\n", endpointID, maxAlive.Minutes())
|
||||
return
|
||||
case <-ctx.Done():
|
||||
err := ctx.Err()
|
||||
log.Printf("[DEBUG] [chisel,KeepTunnelAlive] [endpoint_id: %d] [message: stop as err=%s]\n", endpointID, err)
|
||||
return
|
||||
}
|
||||
case <-maxAliveTicker.C:
|
||||
log.Debug().
|
||||
Int("endpoint_id", int(endpointID)).
|
||||
Float64("timeout_minutes", maxAlive.Minutes()).
|
||||
Msg("KeepTunnelAlive: tunnel keep alive timeout")
|
||||
|
||||
return
|
||||
case <-ctx.Done():
|
||||
err := ctx.Err()
|
||||
log.Debug().
|
||||
Int("endpoint_id", int(endpointID)).
|
||||
Err(err).
|
||||
Msg("KeepTunnelAlive: tunnel stop")
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
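Both versions of keepTunnelAlive boil down to the same select loop: a periodic ticker that does the work, a max-lifetime ticker, and a context for shutdown. Distilled into a generic, runnable form (intervals are illustrative):

package main

import (
	"context"
	"fmt"
	"time"
)

// keepAlive runs work on every tick until either maxAlive elapses or the
// context is cancelled, whichever happens first.
func keepAlive(ctx context.Context, interval, maxAlive time.Duration, work func()) {
	maxTicker := time.NewTicker(maxAlive)
	defer maxTicker.Stop()

	workTicker := time.NewTicker(interval)
	defer workTicker.Stop()

	for {
		select {
		case <-workTicker.C:
			work()
		case <-maxTicker.C:
			return
		case <-ctx.Done():
			return
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 350*time.Millisecond)
	defer cancel()

	keepAlive(ctx, 100*time.Millisecond, time.Minute, func() { fmt.Println("ping") })
}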
// StartTunnelServer starts a tunnel server on the specified addr and port.
|
||||
@@ -146,14 +96,14 @@ func (service *Service) keepTunnelAlive(endpointID portainer.EndpointID, ctx con
|
||||
// It starts the tunnel status verification process in the background.
|
||||
// The snapshotter is used in the tunnel status verification process.
|
||||
func (service *Service) StartTunnelServer(addr, port string, snapshotService portainer.SnapshotService) error {
|
||||
privateKeyFile, err := service.retrievePrivateKeyFile()
|
||||
keySeed, err := service.retrievePrivateKeySeed()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config := &chserver.Config{
|
||||
Reverse: true,
|
||||
KeyFile: privateKeyFile,
|
||||
KeySeed: keySeed,
|
||||
}
|
||||
|
||||
chiselServer, err := chserver.NewServer(config)
|
||||
@@ -164,21 +114,21 @@ func (service *Service) StartTunnelServer(addr, port string, snapshotService por
|
||||
service.serverFingerprint = chiselServer.GetFingerprint()
|
||||
service.serverPort = port
|
||||
|
||||
if err := chiselServer.Start(addr, port); err != nil {
|
||||
err = chiselServer.Start(addr, port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
service.chiselServer = chiselServer
|
||||
|
||||
// TODO: work-around Chisel default behavior.
|
||||
// By default, Chisel will allow anyone to connect if no user exists.
|
||||
username, password := generateRandomCredentials()
|
||||
if err = service.chiselServer.AddUser(username, password, "127.0.0.1"); err != nil {
|
||||
err = service.chiselServer.AddUser(username, password, "127.0.0.1")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
service.snapshotService = snapshotService
|
||||
|
||||
go service.startTunnelVerificationLoop()
|
||||
|
||||
return nil
|
||||
@@ -189,50 +139,30 @@ func (service *Service) StopTunnelServer() error {
|
||||
return service.chiselServer.Close()
|
||||
}
|
||||
|
||||
func (service *Service) retrievePrivateKeyFile() (string, error) {
|
||||
privateKeyFile := service.fileService.GetDefaultChiselPrivateKeyPath()
|
||||
func (service *Service) retrievePrivateKeySeed() (string, error) {
|
||||
var serverInfo *portainer.TunnelServerInfo
|
||||
|
||||
if exists, _ := service.fileService.FileExists(privateKeyFile); exists {
|
||||
log.Info().
|
||||
Str("private-key", privateKeyFile).
|
||||
Msg("found Chisel private key file on disk")
|
||||
serverInfo, err := service.dataStore.TunnelServer().Info()
|
||||
if service.dataStore.IsErrObjectNotFound(err) {
|
||||
keySeed := uniuri.NewLen(16)
|
||||
|
||||
return privateKeyFile, nil
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("private-key", privateKeyFile).
|
||||
Msg("chisel private key file does not exist")
|
||||
|
||||
privateKey, err := ccrypto.GenerateKey("")
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Msg("failed to generate chisel private key")
|
||||
serverInfo = &portainer.TunnelServerInfo{
|
||||
PrivateKeySeed: keySeed,
|
||||
}
|
||||
|
||||
err := service.dataStore.TunnelServer().UpdateInfo(serverInfo)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
} else if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err = service.fileService.StoreChiselPrivateKey(privateKey); err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Msg("failed to save Chisel private key to disk")
|
||||
|
||||
return "", err
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("private-key", privateKeyFile).
|
||||
Msg("generated a new Chisel private key file")
|
||||
|
||||
return privateKeyFile, nil
|
||||
return serverInfo.PrivateKeySeed, nil
|
||||
}
|
||||
|
||||
func (service *Service) startTunnelVerificationLoop() {
|
||||
log.Debug().
|
||||
Float64("check_interval_seconds", tunnelCleanupInterval.Seconds()).
|
||||
Msg("starting tunnel management process")
|
||||
|
||||
log.Printf("[DEBUG] [chisel, monitoring] [check_interval_seconds: %f] [message: starting tunnel management process]", tunnelCleanupInterval.Seconds())
|
||||
ticker := time.NewTicker(tunnelCleanupInterval)
|
||||
|
||||
for {
|
||||
@@ -240,57 +170,52 @@ func (service *Service) startTunnelVerificationLoop() {
|
||||
case <-ticker.C:
|
||||
service.checkTunnels()
|
||||
case <-service.shutdownCtx.Done():
|
||||
log.Debug().Msg("shutting down tunnel service")
|
||||
|
||||
log.Println("[DEBUG] Shutting down tunnel service")
|
||||
if err := service.StopTunnelServer(); err != nil {
|
||||
log.Debug().Err(err).Msg("stopped tunnel service")
|
||||
log.Printf("Stopped tunnel service: %s", err)
|
||||
}
|
||||
|
||||
ticker.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkTunnels finds the first tunnel that has not had any activity recently
|
||||
// and attempts to take a snapshot, then closes it and returns
|
||||
func (service *Service) checkTunnels() {
|
||||
service.mu.RLock()
|
||||
tunnels := make(map[portainer.EndpointID]portainer.TunnelDetails)
|
||||
|
||||
for endpointID, tunnel := range service.activeTunnels {
|
||||
elapsed := time.Since(tunnel.LastActivity)
|
||||
log.Debug().
|
||||
Int("endpoint_id", int(endpointID)).
|
||||
Float64("last_activity_seconds", elapsed.Seconds()).
|
||||
Msg("environment tunnel monitoring")
|
||||
service.mu.Lock()
|
||||
for key, tunnel := range service.tunnelDetailsMap {
|
||||
tunnels[key] = *tunnel
|
||||
}
|
||||
service.mu.Unlock()
|
||||
|
||||
if tunnel.Status == portainer.EdgeAgentManagementRequired && elapsed < activeTimeout {
|
||||
for endpointID, tunnel := range tunnels {
|
||||
if tunnel.LastActivity.IsZero() || tunnel.Status == portainer.EdgeAgentIdle {
|
||||
continue
|
||||
}
|
||||
|
||||
tunnelPort := tunnel.Port
|
||||
elapsed := time.Since(tunnel.LastActivity)
|
||||
log.Printf("[DEBUG] [chisel,monitoring] [endpoint_id: %d] [status: %s] [status_time_seconds: %f] [message: environment tunnel monitoring]", endpointID, tunnel.Status, elapsed.Seconds())
|
||||
|
||||
service.mu.RUnlock()
|
||||
|
||||
log.Debug().
|
||||
Int("endpoint_id", int(endpointID)).
|
||||
Float64("last_activity_seconds", elapsed.Seconds()).
|
||||
Float64("timeout_seconds", activeTimeout.Seconds()).
|
||||
Msg("last activity timeout exceeded")
|
||||
|
||||
if err := service.snapshotEnvironment(endpointID, tunnelPort); err != nil {
|
||||
log.Error().
|
||||
Int("endpoint_id", int(endpointID)).
|
||||
Err(err).
|
||||
Msg("unable to snapshot Edge environment")
|
||||
if tunnel.Status == portainer.EdgeAgentManagementRequired && elapsed.Seconds() < requiredTimeout.Seconds() {
|
||||
continue
|
||||
} else if tunnel.Status == portainer.EdgeAgentManagementRequired && elapsed.Seconds() > requiredTimeout.Seconds() {
|
||||
log.Printf("[DEBUG] [chisel,monitoring] [endpoint_id: %d] [status: %s] [status_time_seconds: %f] [timeout_seconds: %f] [message: REQUIRED state timeout exceeded]", endpointID, tunnel.Status, elapsed.Seconds(), requiredTimeout.Seconds())
|
||||
}
|
||||
|
||||
service.close(endpointID)
|
||||
if tunnel.Status == portainer.EdgeAgentActive && elapsed.Seconds() < activeTimeout.Seconds() {
|
||||
continue
|
||||
} else if tunnel.Status == portainer.EdgeAgentActive && elapsed.Seconds() > activeTimeout.Seconds() {
|
||||
log.Printf("[DEBUG] [chisel,monitoring] [endpoint_id: %d] [status: %s] [status_time_seconds: %f] [timeout_seconds: %f] [message: ACTIVE state timeout exceeded]", endpointID, tunnel.Status, elapsed.Seconds(), activeTimeout.Seconds())
|
||||
|
||||
return
|
||||
err := service.snapshotEnvironment(endpointID, tunnel.Port)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] [snapshot] Unable to snapshot Edge environment (id: %d): %s", endpointID, err)
|
||||
}
|
||||
}
|
||||
|
||||
service.SetTunnelStatusToIdle(portainer.EndpointID(endpointID))
|
||||
}
|
||||
|
||||
service.mu.RUnlock()
|
||||
}
|
||||
|
||||
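The reworked checkTunnels copies the tunnel map while holding the lock and then does the slow per-tunnel work on the copy, so snapshots and pings never block writers. The pattern in isolation, with a hypothetical registry type:

package main

import (
	"fmt"
	"sync"
)

// registry is a stand-in for the service and its tunnelDetailsMap.
type registry struct {
	mu    sync.Mutex
	items map[int]string
}

// snapshot copies the map under the lock so callers can iterate the copy
// without holding the mutex for the duration of the slow work.
func (r *registry) snapshot() map[int]string {
	r.mu.Lock()
	defer r.mu.Unlock()

	out := make(map[int]string, len(r.items))
	for k, v := range r.items {
		out[k] = v
	}
	return out
}

func main() {
	r := &registry{items: map[int]string{1: "tunnel-a", 2: "tunnel-b"}}
	for id, name := range r.snapshot() { // iterate outside the lock
		fmt.Println(id, name)
	}
}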
func (service *Service) snapshotEnvironment(endpointID portainer.EndpointID, tunnelPort int) error {
|
||||
@@ -299,7 +224,14 @@ func (service *Service) snapshotEnvironment(endpointID portainer.EndpointID, tun
|
||||
return err
|
||||
}
|
||||
|
||||
endpoint.URL = fmt.Sprintf("tcp://127.0.0.1:%d", tunnelPort)
|
||||
endpointURL := endpoint.URL
|
||||
|
||||
return service.snapshotService.SnapshotEndpoint(endpoint)
|
||||
endpoint.URL = fmt.Sprintf("tcp://127.0.0.1:%d", tunnelPort)
|
||||
err = service.snapshotService.SnapshotEndpoint(endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
endpoint.URL = endpointURL
|
||||
return service.dataStore.Endpoint().UpdateEndpoint(endpoint.ID, endpoint)
|
||||
}
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
package chisel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func init() {
|
||||
fips.InitFIPS(false)
|
||||
}
|
||||
|
||||
func TestPingAgentPanic(t *testing.T) {
|
||||
endpoint := &portainer.Endpoint{
|
||||
ID: 1,
|
||||
EdgeID: "test-edge-id",
|
||||
Type: portainer.EdgeAgentOnDockerEnvironment,
|
||||
UserTrusted: true,
|
||||
}
|
||||
|
||||
_, store := datastore.MustNewTestStore(t, true, true)
|
||||
|
||||
s := NewService(store, nil, nil)
|
||||
|
||||
defer func() {
|
||||
require.Nil(t, recover())
|
||||
}()
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(pingTimeout + 1*time.Second)
|
||||
})
|
||||
|
||||
ln, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
require.NoError(t, err)
|
||||
|
||||
srv := &http.Server{Handler: mux}
|
||||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- srv.Serve(ln)
|
||||
}()
|
||||
|
||||
err = s.Open(endpoint)
|
||||
require.NoError(t, err)
|
||||
s.activeTunnels[endpoint.ID].Port = ln.Addr().(*net.TCPAddr).Port
|
||||
|
||||
require.Error(t, s.pingAgent(endpoint.ID))
|
||||
require.NoError(t, srv.Shutdown(context.Background()))
|
||||
require.ErrorIs(t, <-errCh, http.ErrServerClosed)
|
||||
}
|
||||
@@ -2,21 +2,14 @@ package chisel
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/internal/edge"
|
||||
"github.com/portainer/portainer/api/internal/edge/cache"
|
||||
"github.com/portainer/portainer/api/internal/endpointutils"
|
||||
"github.com/portainer/portainer/pkg/libcrypto"
|
||||
"github.com/portainer/portainer/pkg/librand"
|
||||
|
||||
"github.com/dchest/uniuri"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/portainer/libcrypto"
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -24,206 +17,153 @@ const (
|
||||
maxAvailablePort = 65535
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNonEdgeEnv = errors.New("cannot open a tunnel for non-edge environments")
|
||||
ErrAsyncEnv = errors.New("cannot open a tunnel for async edge environments")
|
||||
ErrInvalidEnv = errors.New("cannot open a tunnel for an invalid environment")
|
||||
)
|
||||
|
||||
// Open will mark the tunnel as REQUIRED so the agent opens it
|
||||
func (s *Service) Open(endpoint *portainer.Endpoint) error {
|
||||
if !endpointutils.IsEdgeEndpoint(endpoint) {
|
||||
return ErrNonEdgeEnv
|
||||
}
|
||||
|
||||
if endpoint.Edge.AsyncMode {
|
||||
return ErrAsyncEnv
|
||||
}
|
||||
|
||||
if endpoint.ID == 0 || endpoint.EdgeID == "" || !endpoint.UserTrusted {
|
||||
return ErrInvalidEnv
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if _, ok := s.activeTunnels[endpoint.ID]; ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
defer cache.Del(endpoint.ID)
|
||||
|
||||
tun := &portainer.TunnelDetails{
|
||||
Status: portainer.EdgeAgentManagementRequired,
|
||||
Port: s.getUnusedPort(),
|
||||
LastActivity: time.Now(),
|
||||
}
|
||||
|
||||
username, password := generateRandomCredentials()
|
||||
|
||||
if s.chiselServer != nil {
|
||||
authorizedRemote := fmt.Sprintf("^R:0.0.0.0:%d$", tun.Port)
|
||||
|
||||
if err := s.chiselServer.AddUser(username, password, authorizedRemote); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
credentials, err := encryptCredentials(username, password, endpoint.EdgeID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tun.Credentials = credentials
|
||||
|
||||
s.activeTunnels[endpoint.ID] = tun
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// close removes the tunnel from the map so the agent will close it
|
||||
func (s *Service) close(endpointID portainer.EndpointID) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
tun, ok := s.activeTunnels[endpointID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if len(tun.Credentials) > 0 && s.chiselServer != nil {
|
||||
user, _, _ := strings.Cut(tun.Credentials, ":")
|
||||
s.chiselServer.DeleteUser(user)
|
||||
}
|
||||
|
||||
if s.ProxyManager != nil {
|
||||
s.ProxyManager.DeleteEndpointProxy(endpointID)
|
||||
}
|
||||
|
||||
delete(s.activeTunnels, endpointID)
|
||||
|
||||
cache.Del(endpointID)
|
||||
}
|
||||
|
||||
// Config returns the tunnel details needed for the agent to connect
|
||||
func (s *Service) Config(endpointID portainer.EndpointID) portainer.TunnelDetails {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if tun, ok := s.activeTunnels[endpointID]; ok {
|
||||
return *tun
|
||||
}
|
||||
|
||||
return portainer.TunnelDetails{Status: portainer.EdgeAgentIdle}
|
||||
}
|
||||
|
||||
// TunnelAddr returns the address of the local tunnel, including the port, it
|
||||
// will block until the tunnel is ready
|
||||
func (s *Service) TunnelAddr(endpoint *portainer.Endpoint) (string, error) {
|
||||
if err := s.Open(endpoint); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
tun := s.Config(endpoint.ID)
|
||||
checkinInterval := time.Duration(s.tryEffectiveCheckinInterval(endpoint)) * time.Second
|
||||
|
||||
for t0 := time.Now(); ; {
|
||||
if time.Since(t0) > 2*checkinInterval {
|
||||
s.close(endpoint.ID)
|
||||
|
||||
return "", errors.New("unable to open the tunnel")
|
||||
}
|
||||
|
||||
// Check if the tunnel is established
|
||||
conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: tun.Port})
|
||||
if err != nil {
|
||||
time.Sleep(checkinInterval / 100)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Warn().Err(err).Msg("failed to close tcp connection")
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
s.UpdateLastActivity(endpoint.ID)
|
||||
|
||||
return fmt.Sprintf("127.0.0.1:%d", tun.Port), nil
|
||||
}
|
||||
|
||||
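TunnelAddr blocks by repeatedly dialling the local tunnel port until a TCP connection succeeds or a deadline derived from the check-in interval passes. The same wait-for-port strategy as a standalone sketch, with illustrative values:

package main

import (
	"errors"
	"fmt"
	"net"
	"time"
)

// waitForPort polls a local TCP port until it accepts a connection or the
// deadline passes, closing the probe connection on success.
func waitForPort(port int, deadline time.Duration) error {
	start := time.Now()
	for {
		if time.Since(start) > deadline {
			return errors.New("unable to open the tunnel")
		}

		conn, err := net.DialTimeout("tcp", fmt.Sprintf("127.0.0.1:%d", port), time.Second)
		if err == nil {
			return conn.Close()
		}

		time.Sleep(100 * time.Millisecond)
	}
}

func main() {
	fmt.Println(waitForPort(8000, 2*time.Second))
}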
// tryEffectiveCheckinInterval avoids a potential deadlock by returning a
|
||||
// previous known value after a timeout
|
||||
func (s *Service) tryEffectiveCheckinInterval(endpoint *portainer.Endpoint) int {
|
||||
ch := make(chan int, 1)
|
||||
|
||||
go func() {
|
||||
ch <- edge.EffectiveCheckinInterval(s.dataStore, endpoint)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
return s.defaultCheckinInterval
|
||||
case i := <-ch:
|
||||
s.mu.Lock()
|
||||
s.defaultCheckinInterval = i
|
||||
s.mu.Unlock()
|
||||
|
||||
return i
|
||||
}
|
||||
}
|
||||
|
||||
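tryEffectiveCheckinInterval guards against a blocked data-store call by racing the lookup against a 50 ms timer and falling back to the last known value. A generic version of that compute-with-fallback pattern:

package main

import (
	"fmt"
	"time"
)

// withFallback returns compute's result unless it takes longer than timeout,
// in which case a previously known value is returned instead.
func withFallback(compute func() int, fallback int, timeout time.Duration) int {
	ch := make(chan int, 1) // buffered so the goroutine can finish even if the result is discarded
	go func() { ch <- compute() }()

	select {
	case v := <-ch:
		return v
	case <-time.After(timeout):
		return fallback
	}
}

func main() {
	slow := func() int { time.Sleep(200 * time.Millisecond); return 30 }
	fmt.Println(withFallback(slow, 5, 50*time.Millisecond)) // prints 5
}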
// UpdateLastActivity sets the current timestamp to avoid the tunnel timeout
|
||||
func (s *Service) UpdateLastActivity(endpointID portainer.EndpointID) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if tun, ok := s.activeTunnels[endpointID]; ok {
|
||||
tun.LastActivity = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: it needs to be called with the lock acquired
// getUnusedPort is used to generate an unused random port in the dynamic port range.
// Dynamic ports (also called private ports) are 49152 to 65535.
func (service *Service) getUnusedPort() int {
	port := randomInt(minAvailablePort, maxAvailablePort)

	for _, tunnel := range service.activeTunnels {
	for _, tunnel := range service.tunnelDetailsMap {
		if tunnel.Port == port {
			return service.getUnusedPort()
		}
	}

	conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: port})
	if err == nil {
		if err := conn.Close(); err != nil {
			log.Warn().Msg("failed to close tcp connection that checks if port is free")
		}

		log.Debug().
			Int("port", port).
			Msg("selected port is in use, trying a different one")

		return service.getUnusedPort()
	}

	return port
}

func randomInt(min, max int) int {
	return min + librand.Intn(max-min)
	return min + rand.Intn(max-min)
}
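Worth noting about randomInt: rand.Intn(max-min) is exclusive of its upper bound, so with the dynamic-port constants used here (49152 and 65535) the highest port ever returned is 65534. A tiny standalone illustration:

package main

import (
	"fmt"
	"math/rand"
)

// randomPort mimics randomInt(minAvailablePort, maxAvailablePort): the result
// lies in [49152, 65535), i.e. the upper bound itself is never returned.
func randomPort() int {
	const minPort, maxPort = 49152, 65535
	return minPort + rand.Intn(maxPort-minPort)
}

func main() {
	fmt.Println(randomPort())
}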
// NOTE: it needs to be called with the lock acquired
|
||||
func (service *Service) getTunnelDetails(endpointID portainer.EndpointID) *portainer.TunnelDetails {
|
||||
|
||||
if tunnel, ok := service.tunnelDetailsMap[endpointID]; ok {
|
||||
return tunnel
|
||||
}
|
||||
|
||||
tunnel := &portainer.TunnelDetails{
|
||||
Status: portainer.EdgeAgentIdle,
|
||||
}
|
||||
|
||||
service.tunnelDetailsMap[endpointID] = tunnel
|
||||
|
||||
return tunnel
|
||||
}
|
||||
|
||||
// GetTunnelDetails returns information about the tunnel associated to an environment(endpoint).
|
||||
func (service *Service) GetTunnelDetails(endpointID portainer.EndpointID) portainer.TunnelDetails {
|
||||
service.mu.Lock()
|
||||
defer service.mu.Unlock()
|
||||
|
||||
return *service.getTunnelDetails(endpointID)
|
||||
}
|
||||
|
||||
// GetActiveTunnel retrieves an active tunnel which allows communicating with edge agent
|
||||
func (service *Service) GetActiveTunnel(endpoint *portainer.Endpoint) (portainer.TunnelDetails, error) {
|
||||
tunnel := service.GetTunnelDetails(endpoint.ID)
|
||||
|
||||
if tunnel.Status == portainer.EdgeAgentActive {
|
||||
// update the LastActivity
|
||||
service.SetTunnelStatusToActive(endpoint.ID)
|
||||
}
|
||||
|
||||
if tunnel.Status == portainer.EdgeAgentIdle || tunnel.Status == portainer.EdgeAgentManagementRequired {
|
||||
err := service.SetTunnelStatusToRequired(endpoint.ID)
|
||||
if err != nil {
|
||||
return portainer.TunnelDetails{}, fmt.Errorf("failed opening tunnel to endpoint: %w", err)
|
||||
}
|
||||
|
||||
if endpoint.EdgeCheckinInterval == 0 {
|
||||
settings, err := service.dataStore.Settings().Settings()
|
||||
if err != nil {
|
||||
return portainer.TunnelDetails{}, fmt.Errorf("failed fetching settings from db: %w", err)
|
||||
}
|
||||
|
||||
endpoint.EdgeCheckinInterval = settings.EdgeAgentCheckinInterval
|
||||
}
|
||||
|
||||
time.Sleep(2 * time.Duration(endpoint.EdgeCheckinInterval) * time.Second)
|
||||
}
|
||||
|
||||
return service.GetTunnelDetails(endpoint.ID), nil
|
||||
}
|
||||
|
||||
// SetTunnelStatusToActive update the status of the tunnel associated to the specified environment(endpoint).
|
||||
// It sets the status to ACTIVE.
|
||||
func (service *Service) SetTunnelStatusToActive(endpointID portainer.EndpointID) {
|
||||
service.mu.Lock()
|
||||
tunnel := service.getTunnelDetails(endpointID)
|
||||
tunnel.Status = portainer.EdgeAgentActive
|
||||
tunnel.Credentials = ""
|
||||
tunnel.LastActivity = time.Now()
|
||||
service.mu.Unlock()
|
||||
}
|
||||
|
||||
// SetTunnelStatusToIdle update the status of the tunnel associated to the specified environment(endpoint).
|
||||
// It sets the status to IDLE.
|
||||
// It removes any existing credentials associated to the tunnel.
|
||||
func (service *Service) SetTunnelStatusToIdle(endpointID portainer.EndpointID) {
|
||||
service.mu.Lock()
|
||||
|
||||
tunnel := service.getTunnelDetails(endpointID)
|
||||
tunnel.Status = portainer.EdgeAgentIdle
|
||||
tunnel.Port = 0
|
||||
tunnel.LastActivity = time.Now()
|
||||
|
||||
credentials := tunnel.Credentials
|
||||
if credentials != "" {
|
||||
tunnel.Credentials = ""
|
||||
service.chiselServer.DeleteUser(strings.Split(credentials, ":")[0])
|
||||
}
|
||||
|
||||
service.ProxyManager.DeleteEndpointProxy(endpointID)
|
||||
|
||||
service.mu.Unlock()
|
||||
}
|
||||
|
||||
// SetTunnelStatusToRequired update the status of the tunnel associated to the specified environment(endpoint).
|
||||
// It sets the status to REQUIRED.
|
||||
// If no port is currently associated to the tunnel, it will associate a random unused port to the tunnel
|
||||
// and generate temporary credentials that can be used to establish a reverse tunnel on that port.
|
||||
// Credentials are encrypted using the Edge ID associated to the environment(endpoint).
|
||||
func (service *Service) SetTunnelStatusToRequired(endpointID portainer.EndpointID) error {
|
||||
tunnel := service.getTunnelDetails(endpointID)
|
||||
|
||||
service.mu.Lock()
|
||||
defer service.mu.Unlock()
|
||||
|
||||
if tunnel.Port == 0 {
|
||||
endpoint, err := service.dataStore.Endpoint().Endpoint(endpointID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tunnel.Status = portainer.EdgeAgentManagementRequired
|
||||
tunnel.Port = service.getUnusedPort()
|
||||
tunnel.LastActivity = time.Now()
|
||||
|
||||
username, password := generateRandomCredentials()
|
||||
authorizedRemote := fmt.Sprintf("^R:0.0.0.0:%d$", tunnel.Port)
|
||||
err = service.chiselServer.AddUser(username, password, authorizedRemote)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
credentials, err := encryptCredentials(username, password, endpoint.EdgeID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tunnel.Credentials = credentials
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func generateRandomCredentials() (string, string) {
|
||||
username := uniuri.NewLen(8)
|
||||
password := uniuri.NewLen(8)
|
||||
|
||||
return username, password
|
||||
}
|
||||
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
package chisel
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
)
|
||||
|
||||
type testSettingsService struct {
|
||||
dataservices.SettingsService
|
||||
}
|
||||
|
||||
func (s *testSettingsService) Settings() (*portainer.Settings, error) {
|
||||
return &portainer.Settings{
|
||||
EdgeAgentCheckinInterval: 1,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type testStore struct {
|
||||
dataservices.DataStore
|
||||
}
|
||||
|
||||
func (s *testStore) Settings() dataservices.SettingsService {
|
||||
return &testSettingsService{}
|
||||
}
|
||||
|
||||
func TestGetUnusedPort(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
existingTunnels map[portainer.EndpointID]*portainer.TunnelDetails
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "simple case",
|
||||
},
|
||||
{
|
||||
name: "existing tunnels",
|
||||
existingTunnels: map[portainer.EndpointID]*portainer.TunnelDetails{
|
||||
portainer.EndpointID(1): {
|
||||
Port: 53072,
|
||||
},
|
||||
portainer.EndpointID(2): {
|
||||
Port: 63072,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
store := &testStore{}
|
||||
s := NewService(store, nil, nil)
|
||||
s.activeTunnels = tc.existingTunnels
|
||||
port := s.getUnusedPort()
|
||||
|
||||
if port < 49152 || port > 65535 {
	t.Fatalf("Expected port to be between 49152 and 65535 but got %d", port)
}
|
||||
|
||||
for _, tun := range tc.existingTunnels {
|
||||
if tun.Port == port {
|
||||
t.Fatalf("returned port %d already has an existing tunnel", port)
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: port})
|
||||
if err == nil {
|
||||
// Ignore error
|
||||
_ = conn.Close()
|
||||
t.Fatalf("expected port %d to be unused", port)
|
||||
} else if !strings.Contains(err.Error(), "connection refused") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
179
api/cli/cli.go
179
api/cli/cli.go
@@ -2,43 +2,55 @@ package cli
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
|
||||
"github.com/alecthomas/kingpin/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/alecthomas/kingpin.v2"
|
||||
)
|
||||
|
||||
// Service implements the CLIService interface
|
||||
type Service struct{}
|
||||
|
||||
var (
|
||||
ErrInvalidEndpointProtocol = errors.New("Invalid environment protocol: Portainer only supports unix://, npipe:// or tcp://")
|
||||
ErrSocketOrNamedPipeNotFound = errors.New("Unable to locate Unix socket or named pipe")
|
||||
ErrInvalidSnapshotInterval = errors.New("Invalid snapshot interval")
|
||||
ErrAdminPassExcludeAdminPassFile = errors.New("Cannot use --admin-password with --admin-password-file")
|
||||
errInvalidEndpointProtocol = errors.New("Invalid environment protocol: Portainer only supports unix://, npipe:// or tcp://")
|
||||
errSocketOrNamedPipeNotFound = errors.New("Unable to locate Unix socket or named pipe")
|
||||
errInvalidSnapshotInterval = errors.New("Invalid snapshot interval")
|
||||
errAdminPassExcludeAdminPassFile = errors.New("Cannot use --admin-password with --admin-password-file")
|
||||
)
|
||||
|
||||
func CLIFlags() *portainer.CLIFlags {
|
||||
return &portainer.CLIFlags{
|
||||
// ParseFlags parse the CLI flags and return a portainer.Flags struct
|
||||
func (*Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
|
||||
kingpin.Version(version)
|
||||
|
||||
flags := &portainer.CLIFlags{
|
||||
Addr: kingpin.Flag("bind", "Address and port to serve Portainer").Default(defaultBindAddress).Short('p').String(),
|
||||
AddrHTTPS: kingpin.Flag("bind-https", "Address and port to serve Portainer via https").Default(defaultHTTPSBindAddress).String(),
|
||||
TunnelAddr: kingpin.Flag("tunnel-addr", "Address to serve the tunnel server").Default(defaultTunnelServerAddress).String(),
|
||||
TunnelPort: kingpin.Flag("tunnel-port", "Port to serve the tunnel server").Default(defaultTunnelServerPort).String(),
|
||||
Assets: kingpin.Flag("assets", "Path to the assets").Default(defaultAssetsDirectory).Short('a').String(),
|
||||
Data: kingpin.Flag("data", "Path to the folder where the data is stored").Default(defaultDataDirectory).Short('d').String(),
|
||||
DemoEnvironment: kingpin.Flag("demo", "Demo environment").Bool(),
|
||||
EndpointURL: kingpin.Flag("host", "Environment URL").Short('H').String(),
|
||||
FeatureFlags: kingpin.Flag("feat", "List of feature flags").Envar(portainer.FeatureFlagEnvVar).Strings(),
|
||||
FeatureFlags: BoolPairs(kingpin.Flag("feat", "List of feature flags").Hidden()),
|
||||
EnableEdgeComputeFeatures: kingpin.Flag("edge-compute", "Enable Edge Compute features").Bool(),
|
||||
NoAnalytics: kingpin.Flag("no-analytics", "Disable Analytics in app (deprecated)").Bool(),
|
||||
TLS: kingpin.Flag("tlsverify", "TLS support").Default(defaultTLS).Bool(),
|
||||
TLSSkipVerify: kingpin.Flag("tlsskipverify", "Disable TLS server verification").Default(defaultTLSSkipVerify).Bool(),
|
||||
TLSCacert: kingpin.Flag("tlscacert", "Path to the CA").Default(defaultTLSCACertPath).String(),
|
||||
TLSCert: kingpin.Flag("tlscert", "Path to the TLS certificate file").Default(defaultTLSCertPath).String(),
|
||||
TLSKey: kingpin.Flag("tlskey", "Path to the TLS key").Default(defaultTLSKeyPath).String(),
|
||||
HTTPDisabled: kingpin.Flag("http-disabled", "Serve portainer only on https").Default(defaultHTTPDisabled).Bool(),
|
||||
HTTPEnabled: kingpin.Flag("http-enabled", "Serve portainer on http").Default(defaultHTTPEnabled).Bool(),
|
||||
Rollback: kingpin.Flag("rollback", "Rollback the database to the previous backup").Bool(),
|
||||
SSL: kingpin.Flag("ssl", "Secure Portainer instance using SSL (deprecated)").Default(defaultSSL).Bool(),
|
||||
SSLCert: kingpin.Flag("sslcert", "Path to the SSL certificate used to secure the Portainer instance").String(),
|
||||
SSLKey: kingpin.Flag("sslkey", "Path to the SSL key used to secure the Portainer instance").String(),
|
||||
Rollback: kingpin.Flag("rollback", "Rollback the database store to the previous version").Bool(),
|
||||
SnapshotInterval: kingpin.Flag("snapshot-interval", "Duration between each environment snapshot job").String(),
|
||||
AdminPassword: kingpin.Flag("admin-password", "Set admin password with provided hash").String(),
|
||||
AdminPasswordFile: kingpin.Flag("admin-password-file", "Path to the file containing the password for the admin user").String(),
|
||||
@@ -50,54 +62,7 @@ func CLIFlags() *portainer.CLIFlags {
|
||||
MaxBatchSize: kingpin.Flag("max-batch-size", "Maximum size of a batch").Int(),
|
||||
MaxBatchDelay: kingpin.Flag("max-batch-delay", "Maximum delay before a batch starts").Duration(),
|
||||
SecretKeyName: kingpin.Flag("secret-key-name", "Secret key name for encryption and will be used as /run/secrets/<secret-key-name>.").Default(defaultSecretKeyName).String(),
|
||||
LogLevel: kingpin.Flag("log-level", "Set the minimum logging level to show").Default("INFO").Enum("DEBUG", "INFO", "WARN", "ERROR"),
|
||||
LogMode: kingpin.Flag("log-mode", "Set the logging output mode").Default("PRETTY").Enum("NOCOLOR", "PRETTY", "JSON"),
|
||||
PullLimitCheckDisabled: kingpin.Flag("pull-limit-check-disabled", "Pull limit check").Envar(portainer.PullLimitCheckDisabledEnvVar).Default(defaultPullLimitCheckDisabled).Bool(),
|
||||
TrustedOrigins: kingpin.Flag("trusted-origins", "List of trusted origins for CSRF protection. Separate multiple origins with a comma.").Envar(portainer.TrustedOriginsEnvVar).String(),
|
||||
CSP: kingpin.Flag("csp", "Content Security Policy (CSP) header").Envar(portainer.CSPEnvVar).Default("true").Bool(),
|
||||
CompactDB: kingpin.Flag("compact-db", "Enable database compaction on startup").Envar(portainer.CompactDBEnvVar).Default("false").Bool(),
|
||||
}
|
||||
}
|
||||
|
||||
// ParseFlags parse the CLI flags and return a portainer.Flags struct
|
||||
func (Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
|
||||
kingpin.Version(version)
|
||||
|
||||
var hasSSLFlag, hasSSLCertFlag, hasSSLKeyFlag bool
|
||||
sslFlag := kingpin.Flag(
|
||||
"ssl",
|
||||
"Secure Portainer instance using SSL (deprecated)",
|
||||
).Default(defaultSSL).IsSetByUser(&hasSSLFlag)
|
||||
ssl := sslFlag.Bool()
|
||||
sslCertFlag := kingpin.Flag(
|
||||
"sslcert",
|
||||
"Path to the SSL certificate used to secure the Portainer instance",
|
||||
).IsSetByUser(&hasSSLCertFlag)
|
||||
sslCert := sslCertFlag.String()
|
||||
sslKeyFlag := kingpin.Flag(
|
||||
"sslkey",
|
||||
"Path to the SSL key used to secure the Portainer instance",
|
||||
).IsSetByUser(&hasSSLKeyFlag)
|
||||
sslKey := sslKeyFlag.String()
|
||||
|
||||
flags := CLIFlags()
|
||||
|
||||
var hasTLSFlag, hasTLSCertFlag, hasTLSKeyFlag bool
|
||||
tlsFlag := kingpin.Flag("tlsverify", "TLS support").Default(defaultTLS).IsSetByUser(&hasTLSFlag)
|
||||
flags.TLS = tlsFlag.Bool()
|
||||
tlsCertFlag := kingpin.Flag(
|
||||
"tlscert",
|
||||
"Path to the TLS certificate file",
|
||||
).Default(defaultTLSCertPath).IsSetByUser(&hasTLSCertFlag)
|
||||
flags.TLSCert = tlsCertFlag.String()
|
||||
tlsKeyFlag := kingpin.Flag("tlskey", "Path to the TLS key").Default(defaultTLSKeyPath).IsSetByUser(&hasTLSKeyFlag)
|
||||
flags.TLSKey = tlsKeyFlag.String()
|
||||
flags.TLSCacert = kingpin.Flag("tlscacert", "Path to the CA").Default(defaultTLSCACertPath).String()
|
||||
|
||||
flags.KubectlShellImage = kingpin.Flag(
|
||||
"kubectl-shell-image",
|
||||
"Kubectl shell image",
|
||||
).Envar(portainer.KubectlShellImageEnvVar).Default(portainer.DefaultKubectlShellImage).String()
|
||||
|
||||
kingpin.Parse()
|
||||
|
||||
@@ -106,62 +71,29 @@ func (Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
*flags.Assets = filepath.Join(filepath.Dir(ex), *flags.Assets)
|
||||
}
|
||||
|
||||
// If the user didn't provide a tls flag remove the defaults to match previous behaviour
|
||||
if !hasTLSFlag {
|
||||
if !hasTLSCertFlag {
|
||||
*flags.TLSCert = ""
|
||||
}
|
||||
|
||||
if !hasTLSKeyFlag {
|
||||
*flags.TLSKey = ""
|
||||
}
|
||||
}
|
||||
|
||||
if hasSSLFlag {
|
||||
log.Warn().Msgf("the %q flag is deprecated. use %q instead.", sslFlag.Model().Name, tlsFlag.Model().Name)
|
||||
|
||||
if !hasTLSFlag {
|
||||
flags.TLS = ssl
|
||||
}
|
||||
}
|
||||
|
||||
if hasSSLCertFlag {
|
||||
log.Warn().Msgf("the %q flag is deprecated. use %q instead.", sslCertFlag.Model().Name, tlsCertFlag.Model().Name)
|
||||
|
||||
if !hasTLSCertFlag {
|
||||
flags.TLSCert = sslCert
|
||||
}
|
||||
}
|
||||
|
||||
if hasSSLKeyFlag {
|
||||
log.Warn().Msgf("the %q flag is deprecated. use %q instead.", sslKeyFlag.Model().Name, tlsKeyFlag.Model().Name)
|
||||
|
||||
if !hasTLSKeyFlag {
|
||||
flags.TLSKey = sslKey
|
||||
}
|
||||
}
|
||||
|
||||
return flags, nil
|
||||
}
|
||||
|
||||
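The deprecation handling above follows a single precedence rule: a deprecated flag (--ssl, --sslcert, --sslkey) only takes effect when the user set it and did not also set its replacement (--tlsverify, --tlscert, --tlskey). Reduced to a small, runnable truth table; resolveFlag is an illustrative helper, not Portainer code:

package main

import "fmt"

// resolveFlag applies the rule used in ParseFlags: the deprecated value wins
// only when the deprecated flag was set and the replacement flag was not.
func resolveFlag(newSet, oldSet, newVal, oldVal bool) bool {
	if oldSet && !newSet {
		return oldVal
	}
	return newVal
}

func main() {
	// --ssl given, --tlsverify omitted: TLS follows the deprecated flag.
	fmt.Println(resolveFlag(false, true, false, true)) // true
	// Both given: the replacement flag wins.
	fmt.Println(resolveFlag(true, true, false, true)) // false
}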
// ValidateFlags validates the values of the flags.
|
||||
func (Service) ValidateFlags(flags *portainer.CLIFlags) error {
|
||||
func (*Service) ValidateFlags(flags *portainer.CLIFlags) error {
|
||||
|
||||
displayDeprecationWarnings(flags)
|
||||
|
||||
if err := validateEndpointURL(*flags.EndpointURL); err != nil {
|
||||
err := validateEndpointURL(*flags.EndpointURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateSnapshotInterval(*flags.SnapshotInterval); err != nil {
|
||||
err = validateSnapshotInterval(*flags.SnapshotInterval)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if *flags.AdminPassword != "" && *flags.AdminPasswordFile != "" {
|
||||
return ErrAdminPassExcludeAdminPassFile
|
||||
return errAdminPassExcludeAdminPassFile
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -169,43 +101,40 @@ func (Service) ValidateFlags(flags *portainer.CLIFlags) error {
|
||||
|
||||
func displayDeprecationWarnings(flags *portainer.CLIFlags) {
|
||||
if *flags.NoAnalytics {
|
||||
log.Warn().Msg("the --no-analytics flag has been kept to allow migration of instances running a previous version of Portainer with this flag enabled, to version 2.0 where enabling this flag will have no effect")
|
||||
log.Println("Warning: The --no-analytics flag has been kept to allow migration of instances running a previous version of Portainer with this flag enabled, to version 2.0 where enabling this flag will have no effect.")
|
||||
}
|
||||
|
||||
if *flags.SSL {
|
||||
log.Println("Warning: SSL is enabled by default and there is no need for the --ssl flag. It has been kept to allow migration of instances running a previous version of Portainer with this flag enabled")
|
||||
}
|
||||
}
|
||||
|
||||
func validateEndpointURL(endpointURL string) error {
|
||||
if endpointURL == "" {
|
||||
return nil
|
||||
}
|
||||
if endpointURL != "" {
|
||||
if !strings.HasPrefix(endpointURL, "unix://") && !strings.HasPrefix(endpointURL, "tcp://") && !strings.HasPrefix(endpointURL, "npipe://") {
|
||||
return errInvalidEndpointProtocol
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(endpointURL, "unix://") && !strings.HasPrefix(endpointURL, "tcp://") && !strings.HasPrefix(endpointURL, "npipe://") {
|
||||
return ErrInvalidEndpointProtocol
|
||||
}
|
||||
|
||||
if strings.HasPrefix(endpointURL, "unix://") || strings.HasPrefix(endpointURL, "npipe://") {
|
||||
socketPath := strings.TrimPrefix(endpointURL, "unix://")
|
||||
socketPath = strings.TrimPrefix(socketPath, "npipe://")
|
||||
|
||||
if _, err := os.Stat(socketPath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return ErrSocketOrNamedPipeNotFound
|
||||
if strings.HasPrefix(endpointURL, "unix://") || strings.HasPrefix(endpointURL, "npipe://") {
|
||||
socketPath := strings.TrimPrefix(endpointURL, "unix://")
|
||||
socketPath = strings.TrimPrefix(socketPath, "npipe://")
|
||||
if _, err := os.Stat(socketPath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return errSocketOrNamedPipeNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateSnapshotInterval(snapshotInterval string) error {
|
||||
if snapshotInterval == "" {
|
||||
return nil
|
||||
if snapshotInterval != "" {
|
||||
_, err := time.ParseDuration(snapshotInterval)
|
||||
if err != nil {
|
||||
return errInvalidSnapshotInterval
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := time.ParseDuration(snapshotInterval); err != nil {
|
||||
return ErrInvalidSnapshotInterval
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
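validateSnapshotInterval accepts anything time.ParseDuration accepts, which means a unit is mandatory. A quick check of a few candidate values:

package main

import (
	"fmt"
	"time"
)

func main() {
	// Values like "5m", "90s" or "1h30m" are valid snapshot intervals,
	// while a bare number such as "5" is rejected (missing unit).
	for _, v := range []string{"5m", "90s", "5"} {
		_, err := time.ParseDuration(v)
		fmt.Printf("%q valid: %v\n", v, err == nil)
	}
}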
@@ -1,209 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
zerolog "github.com/rs/zerolog/log"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOptionParser(t *testing.T) {
|
||||
p := Service{}
|
||||
require.NotNil(t, p)
|
||||
|
||||
a := os.Args
|
||||
defer func() { os.Args = a }()
|
||||
|
||||
os.Args = []string{"portainer", "--edge-compute"}
|
||||
|
||||
opts, err := p.ParseFlags("2.34.5")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.False(t, *opts.HTTPDisabled)
|
||||
require.True(t, *opts.EnableEdgeComputeFeatures)
|
||||
}
|
||||
|
||||
func TestParseTLSFlags(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedTLSFlag bool
|
||||
expectedTLSCertFlag string
|
||||
expectedTLSKeyFlag string
|
||||
expectedLogMessages []string
|
||||
}{
|
||||
{
|
||||
name: "no flags",
|
||||
expectedTLSFlag: false,
|
||||
expectedTLSCertFlag: "",
|
||||
expectedTLSKeyFlag: "",
|
||||
},
|
||||
{
|
||||
name: "only ssl flag",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--ssl",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "",
|
||||
expectedTLSKeyFlag: "",
|
||||
},
|
||||
{
|
||||
name: "only tls flag",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--tlsverify",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: defaultTLSCertPath,
|
||||
expectedTLSKeyFlag: defaultTLSKeyPath,
|
||||
},
|
||||
{
|
||||
name: "partial ssl flags",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--ssl",
|
||||
"--sslcert=ssl-cert-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "ssl-cert-flag-value",
|
||||
expectedTLSKeyFlag: "",
|
||||
},
|
||||
{
|
||||
name: "partial tls flags",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--tlsverify",
|
||||
"--tlscert=tls-cert-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "tls-cert-flag-value",
|
||||
expectedTLSKeyFlag: defaultTLSKeyPath,
|
||||
},
|
||||
{
|
||||
name: "partial tls and ssl flags",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--tlsverify",
|
||||
"--tlscert=tls-cert-flag-value",
|
||||
"--sslkey=ssl-key-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "tls-cert-flag-value",
|
||||
expectedTLSKeyFlag: "ssl-key-flag-value",
|
||||
},
|
||||
{
|
||||
name: "partial tls and ssl flags 2",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--ssl",
|
||||
"--tlscert=tls-cert-flag-value",
|
||||
"--sslkey=ssl-key-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "tls-cert-flag-value",
|
||||
expectedTLSKeyFlag: "ssl-key-flag-value",
|
||||
},
|
||||
{
|
||||
name: "ssl flags",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--ssl",
|
||||
"--sslcert=ssl-cert-flag-value",
|
||||
"--sslkey=ssl-key-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "ssl-cert-flag-value",
|
||||
expectedTLSKeyFlag: "ssl-key-flag-value",
|
||||
expectedLogMessages: []string{
|
||||
"the \\\"ssl\\\" flag is deprecated. use \\\"tlsverify\\\" instead.",
|
||||
"the \\\"sslcert\\\" flag is deprecated. use \\\"tlscert\\\" instead.",
|
||||
"the \\\"sslkey\\\" flag is deprecated. use \\\"tlskey\\\" instead.",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tls flags",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--tlsverify",
|
||||
"--tlscert=tls-cert-flag-value",
|
||||
"--tlskey=tls-key-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "tls-cert-flag-value",
|
||||
expectedTLSKeyFlag: "tls-key-flag-value",
|
||||
},
|
||||
{
|
||||
name: "tls and ssl flags",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--tlsverify",
|
||||
"--tlscert=tls-cert-flag-value",
|
||||
"--tlskey=tls-key-flag-value",
|
||||
"--ssl",
|
||||
"--sslcert=ssl-cert-flag-value",
|
||||
"--sslkey=ssl-key-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "tls-cert-flag-value",
|
||||
expectedTLSKeyFlag: "tls-key-flag-value",
|
||||
expectedLogMessages: []string{
|
||||
"the \\\"ssl\\\" flag is deprecated. use \\\"tlsverify\\\" instead.",
|
||||
"the \\\"sslcert\\\" flag is deprecated. use \\\"tlscert\\\" instead.",
|
||||
"the \\\"sslkey\\\" flag is deprecated. use \\\"tlskey\\\" instead.",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var logOutput strings.Builder
|
||||
setupLogOutput(t, &logOutput)
|
||||
|
||||
if tc.args == nil {
|
||||
tc.args = []string{"portainer"}
|
||||
}
|
||||
setOsArgs(t, tc.args)
|
||||
|
||||
s := Service{}
|
||||
flags, err := s.ParseFlags("test-version")
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing flags: %v", err)
|
||||
}
|
||||
|
||||
if flags.TLS == nil {
|
||||
t.Fatal("TLS flag was nil")
|
||||
}
|
||||
|
||||
require.Equal(t, tc.expectedTLSFlag, *flags.TLS, "tlsverify flag didn't match")
|
||||
require.Equal(t, tc.expectedTLSCertFlag, *flags.TLSCert, "tlscert flag didn't match")
|
||||
require.Equal(t, tc.expectedTLSKeyFlag, *flags.TLSKey, "tlskey flag didn't match")
|
||||
|
||||
for _, expectedLogMessage := range tc.expectedLogMessages {
|
||||
require.Contains(t, logOutput.String(), expectedLogMessage, "Log didn't contain expected message")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func setOsArgs(t *testing.T, args []string) {
|
||||
t.Helper()
|
||||
previousArgs := os.Args
|
||||
os.Args = args
|
||||
t.Cleanup(func() {
|
||||
os.Args = previousArgs
|
||||
})
|
||||
}
|
||||
|
||||
func setupLogOutput(t *testing.T, w io.Writer) {
|
||||
t.Helper()
|
||||
|
||||
oldLogger := zerolog.Logger
|
||||
zerolog.Logger = zerolog.Output(w)
|
||||
t.Cleanup(func() {
|
||||
zerolog.Logger = oldLogger
|
||||
})
|
||||
}
|
||||
@@ -2,22 +2,23 @@ package cli

import (
	"bufio"
	"fmt"
	"log"
	"os"
	"strings"
)

// Confirm starts a rollback db cli application
func Confirm(message string) (bool, error) {
	fmt.Printf("%s [y/N] ", message)
	log.Printf("%s [y/N]", message)

	reader := bufio.NewReader(os.Stdin)

	answer, err := reader.ReadString('\n')
	if err != nil {
		return false, err
	}
	answer = strings.Replace(answer, "\n", "", -1)
	answer = strings.ToLower(answer)

	return answer == "y" || answer == "yes", nil

	answer = strings.ReplaceAll(answer, "\n", "")
	return strings.EqualFold(answer, "y") || strings.EqualFold(answer, "yes"), nil
}
@@ -1,23 +1,23 @@
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package cli
|
||||
|
||||
const (
|
||||
defaultBindAddress = ":9000"
|
||||
defaultHTTPSBindAddress = ":9443"
|
||||
defaultTunnelServerAddress = "0.0.0.0"
|
||||
defaultTunnelServerPort = "8000"
|
||||
defaultDataDirectory = "/data"
|
||||
defaultAssetsDirectory = "./"
|
||||
defaultTLS = "false"
|
||||
defaultTLSSkipVerify = "false"
|
||||
defaultTLSCACertPath = "/certs/ca.pem"
|
||||
defaultTLSCertPath = "/certs/cert.pem"
|
||||
defaultTLSKeyPath = "/certs/key.pem"
|
||||
defaultHTTPDisabled = "false"
|
||||
defaultHTTPEnabled = "false"
|
||||
defaultSSL = "false"
|
||||
defaultBaseURL = "/"
|
||||
defaultSecretKeyName = "portainer"
|
||||
defaultPullLimitCheckDisabled = "false"
|
||||
defaultBindAddress = ":9000"
|
||||
defaultHTTPSBindAddress = ":9443"
|
||||
defaultTunnelServerAddress = "0.0.0.0"
|
||||
defaultTunnelServerPort = "8000"
|
||||
defaultDataDirectory = "/data"
|
||||
defaultAssetsDirectory = "./"
|
||||
defaultTLS = "false"
|
||||
defaultTLSSkipVerify = "false"
|
||||
defaultTLSCACertPath = "/certs/ca.pem"
|
||||
defaultTLSCertPath = "/certs/cert.pem"
|
||||
defaultTLSKeyPath = "/certs/key.pem"
|
||||
defaultHTTPDisabled = "false"
|
||||
defaultHTTPEnabled = "false"
|
||||
defaultSSL = "false"
|
||||
defaultBaseURL = "/"
|
||||
defaultSecretKeyName = "portainer"
|
||||
)
|
||||
|
||||
@@ -1,22 +1,21 @@
package cli

const (
    defaultBindAddress = ":9000"
    defaultHTTPSBindAddress = ":9443"
    defaultTunnelServerAddress = "0.0.0.0"
    defaultTunnelServerPort = "8000"
    defaultDataDirectory = "C:\\data"
    defaultAssetsDirectory = "./"
    defaultTLS = "false"
    defaultTLSSkipVerify = "false"
    defaultTLSCACertPath = "C:\\certs\\ca.pem"
    defaultTLSCertPath = "C:\\certs\\cert.pem"
    defaultTLSKeyPath = "C:\\certs\\key.pem"
    defaultHTTPDisabled = "false"
    defaultHTTPEnabled = "false"
    defaultSSL = "false"
    defaultSnapshotInterval = "5m"
    defaultBaseURL = "/"
    defaultSecretKeyName = "portainer"
    defaultPullLimitCheckDisabled = "false"
    defaultBindAddress = ":9000"
    defaultHTTPSBindAddress = ":9443"
    defaultTunnelServerAddress = "0.0.0.0"
    defaultTunnelServerPort = "8000"
    defaultDataDirectory = "C:\\data"
    defaultAssetsDirectory = "./"
    defaultTLS = "false"
    defaultTLSSkipVerify = "false"
    defaultTLSCACertPath = "C:\\certs\\ca.pem"
    defaultTLSCertPath = "C:\\certs\\cert.pem"
    defaultTLSKeyPath = "C:\\certs\\key.pem"
    defaultHTTPDisabled = "false"
    defaultHTTPEnabled = "false"
    defaultSSL = "false"
    defaultSnapshotInterval = "5m"
    defaultBaseURL = "/"
    defaultSecretKeyName = "portainer"
)
@@ -6,7 +6,7 @@ import (
    "fmt"
    "strings"

    "github.com/alecthomas/kingpin/v2"
    "gopkg.in/alecthomas/kingpin.v2"
)

type pairList []portainer.Pair

45
api/cli/pairlistbool.go
Normal file
@@ -0,0 +1,45 @@
package cli

import (
    portainer "github.com/portainer/portainer/api"

    "strings"

    "gopkg.in/alecthomas/kingpin.v2"
)

type pairListBool []portainer.Pair

// Set implementation for a list of portainer.Pair
func (l *pairListBool) Set(value string) error {
    p := new(portainer.Pair)

    // default to true. example setting=true is equivalent to setting
    parts := strings.SplitN(value, "=", 2)
    if len(parts) != 2 {
        p.Name = parts[0]
        p.Value = "true"
    } else {
        p.Name = parts[0]
        p.Value = parts[1]
    }

    *l = append(*l, *p)
    return nil
}

// String implementation for a list of pair
func (l *pairListBool) String() string {
    return ""
}

// IsCumulative implementation for a list of pair
func (l *pairListBool) IsCumulative() bool {
    return true
}

func BoolPairs(s kingpin.Settings) (target *[]portainer.Pair) {
    target = new([]portainer.Pair)
    s.SetValue((*pairListBool)(target))
    return
}
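To make the Set semantics above concrete, a small self-contained sketch of the same parsing rule; the flag names are examples only, and parsePair is a stand-in written for illustration rather than a function from this change.

package main

import (
    "fmt"
    "strings"
)

// Mirrors the Set logic above: a bare name defaults its value to "true".
func parsePair(value string) (name, val string) {
    parts := strings.SplitN(value, "=", 2)
    if len(parts) != 2 {
        return parts[0], "true"
    }
    return parts[0], parts[1]
}

func main() {
    for _, in := range []string{"open-amt", "open-amt=true", "open-amt=false"} {
        n, v := parsePair(in)
        fmt.Printf("--feat %s => Pair{Name: %q, Value: %q}\n", in, n, v)
    }
}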
29
api/cmd/portainer/import.go
Normal file
@@ -0,0 +1,29 @@
package main

import (
    "log"

    portainer "github.com/portainer/portainer/api"
    "github.com/portainer/portainer/api/datastore"
    "github.com/sirupsen/logrus"
)

func importFromJson(fileService portainer.FileService, store *datastore.Store) {
    // EXPERIMENTAL - if used with an incomplete json file, it will fail, as we don't have a way to default the model values
    importFile := "/data/import.json"
    if exists, _ := fileService.FileExists(importFile); exists {
        if err := store.Import(importFile); err != nil {
            logrus.WithError(err).Debugf("Import %s failed", importFile)

            // TODO: should really rollback on failure, but then we have nothing.
        } else {
            logrus.Printf("Successfully imported %s to new portainer database", importFile)
        }
        // TODO: this is bad - its to ensure that any defaults that were broken in import, or migrations get set back to what we want
        // I also suspect that everything from "Init to Init" is potentially a migration
        err := store.Init()
        if err != nil {
            log.Fatalf("Failed initializing data store: %v", err)
        }
    }
}
20
api/cmd/portainer/log.go
Normal file
@@ -0,0 +1,20 @@
package main

import (
    "log"

    "github.com/sirupsen/logrus"
)

func configureLogger() {
    logger := logrus.New() // logger is to implicitly substitute stdlib's log
    log.SetOutput(logger.Writer())

    formatter := &logrus.TextFormatter{DisableTimestamp: false, DisableLevelTruncation: true}

    logger.SetFormatter(formatter)
    logrus.SetFormatter(formatter)

    logger.SetLevel(logrus.DebugLevel)
    logrus.SetLevel(logrus.DebugLevel)
}
@@ -1,13 +1,19 @@
package main

import (
    "cmp"
    "context"
    "crypto/sha256"
    "fmt"
    "log"
    "os"
    "path"
    "strconv"
    "strings"
    "time"

    "github.com/sirupsen/logrus"

    "github.com/portainer/libhelm"
    portainer "github.com/portainer/portainer/api"
    "github.com/portainer/portainer/api/apikey"
    "github.com/portainer/portainer/api/chisel"
@@ -15,78 +21,57 @@ import (
    "github.com/portainer/portainer/api/crypto"
    "github.com/portainer/portainer/api/database"
    "github.com/portainer/portainer/api/database/boltdb"
    "github.com/portainer/portainer/api/database/models"
    "github.com/portainer/portainer/api/dataservices"
    "github.com/portainer/portainer/api/datastore"
    "github.com/portainer/portainer/api/datastore/migrator"
    "github.com/portainer/portainer/api/datastore/postinit"
    "github.com/portainer/portainer/api/demo"
    "github.com/portainer/portainer/api/docker"
    dockerclient "github.com/portainer/portainer/api/docker/client"
    "github.com/portainer/portainer/api/exec"
    "github.com/portainer/portainer/api/filesystem"
    "github.com/portainer/portainer/api/git"
    "github.com/portainer/portainer/api/hostmanagement/openamt"
    "github.com/portainer/portainer/api/http"
    "github.com/portainer/portainer/api/http/client"
    "github.com/portainer/portainer/api/http/proxy"
    kubeproxy "github.com/portainer/portainer/api/http/proxy/factory/kubernetes"
    "github.com/portainer/portainer/api/internal/authorization"
    "github.com/portainer/portainer/api/internal/edge/edgestacks"
    "github.com/portainer/portainer/api/internal/endpointutils"
    "github.com/portainer/portainer/api/internal/edge"
    "github.com/portainer/portainer/api/internal/snapshot"
    "github.com/portainer/portainer/api/internal/ssl"
    "github.com/portainer/portainer/api/internal/upgrade"
    "github.com/portainer/portainer/api/jwt"
    "github.com/portainer/portainer/api/kubernetes"
    kubecli "github.com/portainer/portainer/api/kubernetes/cli"
    "github.com/portainer/portainer/api/ldap"
    "github.com/portainer/portainer/api/logs"
    "github.com/portainer/portainer/api/oauth"
    "github.com/portainer/portainer/api/pendingactions"
    "github.com/portainer/portainer/api/pendingactions/actions"
    "github.com/portainer/portainer/api/pendingactions/handlers"
    "github.com/portainer/portainer/api/platform"
    "github.com/portainer/portainer/api/scheduler"
    "github.com/portainer/portainer/api/stacks/deployments"
    "github.com/portainer/portainer/pkg/build"
    "github.com/portainer/portainer/pkg/featureflags"
    "github.com/portainer/portainer/pkg/fips"
    "github.com/portainer/portainer/pkg/libhelm"
    libhelmtypes "github.com/portainer/portainer/pkg/libhelm/types"
    "github.com/portainer/portainer/pkg/libstack/compose"
    "github.com/portainer/portainer/pkg/validate"

    "github.com/google/uuid"
    "github.com/rs/zerolog/log"
    "github.com/portainer/portainer/api/stacks"
)
func initCLI() *portainer.CLIFlags {
    cliService := cli.Service{}

    var cliService portainer.CLIService = &cli.Service{}
    flags, err := cliService.ParseFlags(portainer.APIVersion)
    if err != nil {
        log.Fatal().Err(err).Msg("failed parsing flags")
        logrus.Fatalf("Failed parsing flags: %v", err)
    }

    if err := cliService.ValidateFlags(flags); err != nil {
        log.Fatal().Err(err).Msg("failed validating flags")
    err = cliService.ValidateFlags(flags)
    if err != nil {
        logrus.Fatalf("Failed validating flags:%v", err)
    }

    return flags
}

func initFileService(dataStorePath string) portainer.FileService {
    fileService, err := filesystem.NewService(dataStorePath, "")
    if err != nil {
        log.Fatal().Err(err).Msg("failed creating file service")
        logrus.Fatalf("Failed creating file service: %v", err)
    }

    return fileService
}

func initDataStore(flags *portainer.CLIFlags, secretKey []byte, fileService portainer.FileService, shutdownCtx context.Context) dataservices.DataStore {
    connection, err := database.NewDatabase("boltdb", *flags.Data, secretKey, *flags.CompactDB)
    connection, err := database.NewDatabase("boltdb", *flags.Data, secretKey)
    if err != nil {
        log.Fatal().Err(err).Msg("failed creating database connection")
        logrus.Fatalf("failed creating database connection: %s", err)
    }

    if bconn, ok := connection.(*boltdb.DbConnection); ok {
@@ -94,109 +79,138 @@ func initDataStore(flags *portainer.CLIFlags, secretKey []byte, fileService port
|
||||
bconn.MaxBatchDelay = *flags.MaxBatchDelay
|
||||
bconn.InitialMmapSize = *flags.InitialMmapSize
|
||||
} else {
|
||||
log.Fatal().Msg("failed creating database connection: expecting a boltdb database type but a different one was received")
|
||||
logrus.Fatalf("failed creating database connection: expecting a boltdb database type but a different one was received")
|
||||
}
|
||||
|
||||
store := datastore.NewStore(flags, fileService, connection)
|
||||
|
||||
store := datastore.NewStore(*flags.Data, fileService, connection)
|
||||
isNew, err := store.Open()
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("failed opening store")
|
||||
logrus.Fatalf("Failed opening store: %v", err)
|
||||
}
|
||||
|
||||
if *flags.Rollback {
|
||||
if err := store.Rollback(false); err != nil {
|
||||
log.Fatal().Err(err).Msg("failed rolling back")
|
||||
err := store.Rollback(false)
|
||||
if err != nil {
|
||||
logrus.Fatalf("Failed rolling back: %v", err)
|
||||
}
|
||||
|
||||
log.Info().Msg("exiting rollback")
|
||||
logrus.Println("Exiting rollback")
|
||||
os.Exit(0)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Init sets some defaults - it's basically a migration
|
||||
if err := store.Init(); err != nil {
|
||||
log.Fatal().Err(err).Msg("failed initializing data store")
|
||||
err = store.Init()
|
||||
if err != nil {
|
||||
logrus.Fatalf("Failed initializing data store: %v", err)
|
||||
}
|
||||
|
||||
if isNew {
|
||||
instanceId, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("failed generating instance id")
|
||||
}
|
||||
|
||||
migratorInstance := migrator.NewMigrator(&migrator.MigratorParameters{Flags: flags})
|
||||
migratorCount := migratorInstance.GetMigratorCountOfCurrentAPIVersion()
|
||||
|
||||
// from MigrateData
|
||||
v := models.Version{
|
||||
SchemaVersion: portainer.APIVersion,
|
||||
Edition: int(portainer.PortainerCE),
|
||||
InstanceID: instanceId.String(),
|
||||
MigratorCount: migratorCount,
|
||||
}
|
||||
store.VersionService.StoreDBVersion(portainer.DBVersion)
|
||||
|
||||
if err := store.VersionService.UpdateVersion(&v); err != nil {
|
||||
log.Fatal().Err(err).Msg("failed to update version")
|
||||
err := updateSettingsFromFlags(store, flags)
|
||||
if err != nil {
|
||||
logrus.Fatalf("Failed updating settings from flags: %v", err)
|
||||
}
|
||||
|
||||
if err := updateSettingsFromFlags(store, flags); err != nil {
|
||||
log.Fatal().Err(err).Msg("failed updating settings from flags")
|
||||
} else {
|
||||
storedVersion, err := store.VersionService.DBVersion()
|
||||
if err != nil {
|
||||
logrus.Fatalf("Something Failed during creation of new database: %v", err)
|
||||
}
|
||||
if storedVersion != portainer.DBVersion {
|
||||
err = store.MigrateData()
|
||||
if err != nil {
|
||||
logrus.Fatalf("Failed migration: %v", err)
|
||||
}
|
||||
}
|
||||
} else if err := store.MigrateData(); err != nil {
|
||||
log.Fatal().Err(err).Msg("failed migration")
|
||||
}
|
||||
|
||||
if err := updateSettingsFromFlags(store, flags); err != nil {
|
||||
log.Fatal().Err(err).Msg("failed updating settings from flags")
|
||||
err = updateSettingsFromFlags(store, flags)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed updating settings from flags: %v", err)
|
||||
}
|
||||
|
||||
// this is for the db restore functionality - needs more tests.
|
||||
go func() {
|
||||
<-shutdownCtx.Done()
|
||||
defer connection.Close()
|
||||
|
||||
defer logs.CloseAndLogErr(connection)
|
||||
exportFilename := path.Join(*flags.Data, fmt.Sprintf("export-%d.json", time.Now().Unix()))
|
||||
|
||||
err := store.Export(exportFilename)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Debugf("Failed to export to %s", exportFilename)
|
||||
} else {
|
||||
logrus.Debugf("exported to %s", exportFilename)
|
||||
}
|
||||
connection.Close()
|
||||
}()
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
// checkDBSchemaServerVersionMatch checks if the server version matches the db schema version
|
||||
func checkDBSchemaServerVersionMatch(dbStore dataservices.DataStore, serverVersion string, serverEdition int) bool {
|
||||
v, err := dbStore.Version().Version()
|
||||
func initComposeStackManager(assetsPath string, configPath string, reverseTunnelService portainer.ReverseTunnelService, proxyManager *proxy.Manager) portainer.ComposeStackManager {
|
||||
composeWrapper, err := exec.NewComposeStackManager(assetsPath, configPath, proxyManager)
|
||||
if err != nil {
|
||||
return false
|
||||
logrus.Fatalf("Failed creating compose manager: %v", err)
|
||||
}
|
||||
|
||||
return v.SchemaVersion == serverVersion && v.Edition == serverEdition
|
||||
return composeWrapper
|
||||
}
|
||||
|
||||
func initKubernetesDeployer(kubernetesTokenCacheManager *kubeproxy.TokenCacheManager, kubernetesClientFactory *kubecli.ClientFactory, dataStore dataservices.DataStore, reverseTunnelService portainer.ReverseTunnelService, signatureService portainer.DigitalSignatureService, proxyManager *proxy.Manager) portainer.KubernetesDeployer {
|
||||
return exec.NewKubernetesDeployer(kubernetesTokenCacheManager, kubernetesClientFactory, dataStore, reverseTunnelService, signatureService, proxyManager)
|
||||
func initSwarmStackManager(
|
||||
assetsPath string,
|
||||
configPath string,
|
||||
signatureService portainer.DigitalSignatureService,
|
||||
fileService portainer.FileService,
|
||||
reverseTunnelService portainer.ReverseTunnelService,
|
||||
dataStore dataservices.DataStore,
|
||||
) (portainer.SwarmStackManager, error) {
|
||||
return exec.NewSwarmStackManager(assetsPath, configPath, signatureService, fileService, reverseTunnelService, dataStore)
|
||||
}
|
||||
|
||||
func initHelmPackageManager() (libhelmtypes.HelmPackageManager, error) {
|
||||
return libhelm.NewHelmPackageManager()
|
||||
func initKubernetesDeployer(kubernetesTokenCacheManager *kubeproxy.TokenCacheManager, kubernetesClientFactory *kubecli.ClientFactory, dataStore dataservices.DataStore, reverseTunnelService portainer.ReverseTunnelService, signatureService portainer.DigitalSignatureService, proxyManager *proxy.Manager, assetsPath string) portainer.KubernetesDeployer {
|
||||
return exec.NewKubernetesDeployer(kubernetesTokenCacheManager, kubernetesClientFactory, dataStore, reverseTunnelService, signatureService, proxyManager, assetsPath)
|
||||
}
|
||||
|
||||
func initHelmPackageManager(assetsPath string) (libhelm.HelmPackageManager, error) {
|
||||
return libhelm.NewHelmPackageManager(libhelm.HelmConfig{BinaryPath: assetsPath})
|
||||
}
|
||||
|
||||
func initAPIKeyService(datastore dataservices.DataStore) apikey.APIKeyService {
|
||||
return apikey.NewAPIKeyService(datastore.APIKeyRepository(), datastore.User())
|
||||
}
|
||||
|
||||
func initJWTService(userSessionTimeout string, dataStore dataservices.DataStore) (portainer.JWTService, error) {
|
||||
if userSessionTimeout == "" {
|
||||
userSessionTimeout = portainer.DefaultUserSessionTimeout
|
||||
func initJWTService(userSessionTimeout string, dataStore dataservices.DataStore) (dataservices.JWTService, error) {
|
||||
jwtService, err := jwt.NewService(userSessionTimeout, dataStore)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return jwt.NewService(userSessionTimeout, dataStore)
|
||||
return jwtService, nil
|
||||
}
|
||||
|
||||
func initDigitalSignatureService() portainer.DigitalSignatureService {
|
||||
return crypto.NewECDSAService(os.Getenv("AGENT_SECRET"))
|
||||
}
|
||||
|
||||
func initCryptoService() portainer.CryptoService {
|
||||
return &crypto.Service{}
|
||||
}
|
||||
|
||||
func initLDAPService() portainer.LDAPService {
|
||||
return &ldap.Service{}
|
||||
}
|
||||
|
||||
func initOAuthService() portainer.OAuthService {
|
||||
return oauth.NewService()
|
||||
}
|
||||
|
||||
func initGitService() portainer.GitService {
|
||||
return git.NewService()
|
||||
}
|
||||
|
||||
func initSSLService(addr, certPath, keyPath string, fileService portainer.FileService, dataStore dataservices.DataStore, shutdownTrigger context.CancelFunc) (*ssl.Service, error) {
|
||||
slices := strings.Split(addr, ":")
|
||||
|
||||
host := slices[0]
|
||||
if host == "" {
|
||||
host = "0.0.0.0"
|
||||
@@ -204,25 +218,27 @@ func initSSLService(addr, certPath, keyPath string, fileService portainer.FileSe
|
||||
|
||||
sslService := ssl.NewService(fileService, dataStore, shutdownTrigger)
|
||||
|
||||
if err := sslService.Init(host, certPath, keyPath); err != nil {
|
||||
err := sslService.Init(host, certPath, keyPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return sslService, nil
|
||||
}
|
||||
|
||||
func initSnapshotService(
|
||||
snapshotIntervalFromFlag string,
|
||||
dataStore dataservices.DataStore,
|
||||
dockerClientFactory *dockerclient.ClientFactory,
|
||||
kubernetesClientFactory *kubecli.ClientFactory,
|
||||
shutdownCtx context.Context,
|
||||
pendingActionsService *pendingactions.PendingActionsService,
|
||||
) (portainer.SnapshotService, error) {
|
||||
func initDockerClientFactory(signatureService portainer.DigitalSignatureService, reverseTunnelService portainer.ReverseTunnelService) *docker.ClientFactory {
|
||||
return docker.NewClientFactory(signatureService, reverseTunnelService)
|
||||
}
|
||||
|
||||
func initKubernetesClientFactory(signatureService portainer.DigitalSignatureService, reverseTunnelService portainer.ReverseTunnelService, instanceID string, dataStore dataservices.DataStore) *kubecli.ClientFactory {
|
||||
return kubecli.NewClientFactory(signatureService, reverseTunnelService, instanceID, dataStore)
|
||||
}
|
||||
|
||||
func initSnapshotService(snapshotIntervalFromFlag string, dataStore dataservices.DataStore, dockerClientFactory *docker.ClientFactory, kubernetesClientFactory *kubecli.ClientFactory, shutdownCtx context.Context) (portainer.SnapshotService, error) {
|
||||
dockerSnapshotter := docker.NewSnapshotter(dockerClientFactory)
|
||||
kubernetesSnapshotter := kubernetes.NewSnapshotter(kubernetesClientFactory)
|
||||
|
||||
snapshotService, err := snapshot.NewService(snapshotIntervalFromFlag, dataStore, dockerSnapshotter, kubernetesSnapshotter, shutdownCtx, pendingActionsService)
|
||||
snapshotService, err := snapshot.NewService(snapshotIntervalFromFlag, dataStore, dockerSnapshotter, kubernetesSnapshotter, shutdownCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -243,21 +259,34 @@ func updateSettingsFromFlags(dataStore dataservices.DataStore, flags *portainer.
|
||||
return err
|
||||
}
|
||||
|
||||
settings.SnapshotInterval = cmp.Or(*flags.SnapshotInterval, settings.SnapshotInterval)
|
||||
settings.LogoURL = cmp.Or(*flags.Logo, settings.LogoURL)
|
||||
settings.EnableEdgeComputeFeatures = cmp.Or(*flags.EnableEdgeComputeFeatures, settings.EnableEdgeComputeFeatures)
|
||||
settings.TemplatesURL = cmp.Or(*flags.Templates, settings.TemplatesURL)
|
||||
if *flags.SnapshotInterval != "" {
|
||||
settings.SnapshotInterval = *flags.SnapshotInterval
|
||||
}
|
||||
|
||||
if *flags.Logo != "" {
|
||||
settings.LogoURL = *flags.Logo
|
||||
}
|
||||
|
||||
if *flags.EnableEdgeComputeFeatures {
|
||||
settings.EnableEdgeComputeFeatures = *flags.EnableEdgeComputeFeatures
|
||||
}
|
||||
|
||||
if *flags.Templates != "" {
|
||||
settings.TemplatesURL = *flags.Templates
|
||||
}
|
||||
|
||||
if *flags.Labels != nil {
|
||||
settings.BlackListedLabels = *flags.Labels
|
||||
}
|
||||
|
||||
settings.AgentSecret = ""
|
||||
if agentKey, ok := os.LookupEnv("AGENT_SECRET"); ok {
|
||||
settings.AgentSecret = agentKey
|
||||
} else {
|
||||
settings.AgentSecret = ""
|
||||
}
|
||||
|
||||
if err := dataStore.Settings().UpdateSettings(settings); err != nil {
|
||||
err = dataStore.Settings().UpdateSettings(settings)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -272,7 +301,55 @@ func updateSettingsFromFlags(dataStore dataservices.DataStore, flags *portainer.
|
||||
sslSettings.HTTPEnabled = true
|
||||
}
|
||||
|
||||
return dataStore.SSLSettings().UpdateSettings(sslSettings)
|
||||
err = dataStore.SSLSettings().UpdateSettings(sslSettings)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// enableFeaturesFromFlags turns on or off feature flags
// e.g. portainer --feat open-amt --feat fdo=true ... (defaults to true)
// note, settings are persisted to the DB. To turn off `--feat open-amt=false`
func enableFeaturesFromFlags(dataStore dataservices.DataStore, flags *portainer.CLIFlags) error {
    settings, err := dataStore.Settings().Settings()
    if err != nil {
        return err
    }

    if settings.FeatureFlagSettings == nil {
        settings.FeatureFlagSettings = make(map[portainer.Feature]bool)
    }

    // loop through feature flags to check if they are supported
    for _, feat := range *flags.FeatureFlags {
        var correspondingFeature *portainer.Feature
        for i, supportedFeat := range portainer.SupportedFeatureFlags {
            if strings.EqualFold(feat.Name, string(supportedFeat)) {
                correspondingFeature = &portainer.SupportedFeatureFlags[i]
            }
        }

        if correspondingFeature == nil {
            return fmt.Errorf("unknown feature flag '%s'", feat.Name)
        }

        featureState, err := strconv.ParseBool(feat.Value)
        if err != nil {
            return fmt.Errorf("feature flag's '%s' value should be true or false", feat.Name)
        }

        if featureState {
            logrus.Printf("Feature %v : on", *correspondingFeature)
        } else {
            logrus.Printf("Feature %v : off", *correspondingFeature)
        }

        settings.FeatureFlagSettings[*correspondingFeature] = featureState
    }

    return dataStore.Settings().UpdateSettings(settings)
}
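As a rough illustration of the flag handling described in the comment above, here is a minimal standalone sketch of the same lookup-and-parse pattern; the supported list and flag names are invented for the example, the real list lives in portainer.SupportedFeatureFlags.

package main

import (
    "fmt"
    "strconv"
    "strings"
)

// A minimal sketch of the validation performed above, against a hypothetical
// supported-flag list (not the real portainer.SupportedFeatureFlags).
func resolveFeature(name, value string, supported []string) (string, bool, error) {
    var match string
    for _, s := range supported {
        if strings.EqualFold(name, s) {
            match = s
        }
    }
    if match == "" {
        return "", false, fmt.Errorf("unknown feature flag '%s'", name)
    }

    state, err := strconv.ParseBool(value)
    if err != nil {
        return "", false, fmt.Errorf("feature flag's '%s' value should be true or false", name)
    }
    return match, state, nil
}

func main() {
    supported := []string{"open-amt", "fdo"}
    fmt.Println(resolveFeature("open-amt", "true", supported)) // open-amt true <nil>
    fmt.Println(resolveFeature("fdo", "false", supported))     // fdo false <nil>
    fmt.Println(resolveFeature("unknown", "true", supported))  // unknown feature flag error
}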
|
||||
func loadAndParseKeyPair(fileService portainer.FileService, signatureService portainer.DigitalSignatureService) error {
|
||||
@@ -280,7 +357,6 @@ func loadAndParseKeyPair(fileService portainer.FileService, signatureService por
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return signatureService.ParseKeyPair(private, public)
|
||||
}
|
||||
|
||||
@@ -289,149 +365,243 @@ func generateAndStoreKeyPair(fileService portainer.FileService, signatureService
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
privateHeader, publicHeader := signatureService.PEMHeaders()
|
||||
|
||||
return fileService.StoreKeyPair(private, public, privateHeader, publicHeader)
|
||||
}
|
||||
|
||||
func initKeyPair(fileService portainer.FileService, signatureService portainer.DigitalSignatureService) error {
|
||||
existingKeyPair, err := fileService.KeyPairFilesExist()
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("failed checking for existing key pair")
|
||||
logrus.Fatalf("Failed checking for existing key pair: %v", err)
|
||||
}
|
||||
|
||||
if existingKeyPair {
|
||||
return loadAndParseKeyPair(fileService, signatureService)
|
||||
}
|
||||
|
||||
return generateAndStoreKeyPair(fileService, signatureService)
|
||||
}
|
||||
|
||||
// dbSecretPath build the path to the file that contains the db encryption
|
||||
// secret. Normally in Docker this is built from the static path inside
|
||||
// /run/secrets for example: /run/secrets/<keyFilenameFlag> but for ease of
|
||||
// use outside Docker it also accepts an absolute path
|
||||
func dbSecretPath(keyFilenameFlag string) string {
|
||||
if path.IsAbs(keyFilenameFlag) {
|
||||
return keyFilenameFlag
|
||||
func createTLSSecuredEndpoint(flags *portainer.CLIFlags, dataStore dataservices.DataStore, snapshotService portainer.SnapshotService) error {
|
||||
tlsConfiguration := portainer.TLSConfiguration{
|
||||
TLS: *flags.TLS,
|
||||
TLSSkipVerify: *flags.TLSSkipVerify,
|
||||
}
|
||||
return path.Join("/run/secrets", keyFilenameFlag)
|
||||
|
||||
if *flags.TLS {
|
||||
tlsConfiguration.TLSCACertPath = *flags.TLSCacert
|
||||
tlsConfiguration.TLSCertPath = *flags.TLSCert
|
||||
tlsConfiguration.TLSKeyPath = *flags.TLSKey
|
||||
} else if !*flags.TLS && *flags.TLSSkipVerify {
|
||||
tlsConfiguration.TLS = true
|
||||
}
|
||||
|
||||
endpointID := dataStore.Endpoint().GetNextIdentifier()
|
||||
endpoint := &portainer.Endpoint{
|
||||
ID: portainer.EndpointID(endpointID),
|
||||
Name: "primary",
|
||||
URL: *flags.EndpointURL,
|
||||
GroupID: portainer.EndpointGroupID(1),
|
||||
Type: portainer.DockerEnvironment,
|
||||
TLSConfig: tlsConfiguration,
|
||||
UserAccessPolicies: portainer.UserAccessPolicies{},
|
||||
TeamAccessPolicies: portainer.TeamAccessPolicies{},
|
||||
TagIDs: []portainer.TagID{},
|
||||
Status: portainer.EndpointStatusUp,
|
||||
Snapshots: []portainer.DockerSnapshot{},
|
||||
Kubernetes: portainer.KubernetesDefault(),
|
||||
|
||||
SecuritySettings: portainer.EndpointSecuritySettings{
|
||||
AllowVolumeBrowserForRegularUsers: false,
|
||||
EnableHostManagementFeatures: false,
|
||||
|
||||
AllowSysctlSettingForRegularUsers: true,
|
||||
AllowBindMountsForRegularUsers: true,
|
||||
AllowPrivilegedModeForRegularUsers: true,
|
||||
AllowHostNamespaceForRegularUsers: true,
|
||||
AllowContainerCapabilitiesForRegularUsers: true,
|
||||
AllowDeviceMappingForRegularUsers: true,
|
||||
AllowStackManagementForRegularUsers: true,
|
||||
},
|
||||
}
|
||||
|
||||
if strings.HasPrefix(endpoint.URL, "tcp://") {
|
||||
tlsConfig, err := crypto.CreateTLSConfigurationFromDisk(tlsConfiguration.TLSCACertPath, tlsConfiguration.TLSCertPath, tlsConfiguration.TLSKeyPath, tlsConfiguration.TLSSkipVerify)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
agentOnDockerEnvironment, err := client.ExecutePingOperation(endpoint.URL, tlsConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if agentOnDockerEnvironment {
|
||||
endpoint.Type = portainer.AgentOnDockerEnvironment
|
||||
}
|
||||
}
|
||||
|
||||
err := snapshotService.SnapshotEndpoint(endpoint)
|
||||
if err != nil {
|
||||
logrus.Printf("http error: environment snapshot error (environment=%s, URL=%s) (err=%s)\n", endpoint.Name, endpoint.URL, err)
|
||||
}
|
||||
|
||||
return dataStore.Endpoint().Create(endpoint)
|
||||
}
|
||||
|
||||
func createUnsecuredEndpoint(endpointURL string, dataStore dataservices.DataStore, snapshotService portainer.SnapshotService) error {
|
||||
if strings.HasPrefix(endpointURL, "tcp://") {
|
||||
_, err := client.ExecutePingOperation(endpointURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
endpointID := dataStore.Endpoint().GetNextIdentifier()
|
||||
endpoint := &portainer.Endpoint{
|
||||
ID: portainer.EndpointID(endpointID),
|
||||
Name: "primary",
|
||||
URL: endpointURL,
|
||||
GroupID: portainer.EndpointGroupID(1),
|
||||
Type: portainer.DockerEnvironment,
|
||||
TLSConfig: portainer.TLSConfiguration{},
|
||||
UserAccessPolicies: portainer.UserAccessPolicies{},
|
||||
TeamAccessPolicies: portainer.TeamAccessPolicies{},
|
||||
TagIDs: []portainer.TagID{},
|
||||
Status: portainer.EndpointStatusUp,
|
||||
Snapshots: []portainer.DockerSnapshot{},
|
||||
Kubernetes: portainer.KubernetesDefault(),
|
||||
|
||||
SecuritySettings: portainer.EndpointSecuritySettings{
|
||||
AllowVolumeBrowserForRegularUsers: false,
|
||||
EnableHostManagementFeatures: false,
|
||||
|
||||
AllowSysctlSettingForRegularUsers: true,
|
||||
AllowBindMountsForRegularUsers: true,
|
||||
AllowPrivilegedModeForRegularUsers: true,
|
||||
AllowHostNamespaceForRegularUsers: true,
|
||||
AllowContainerCapabilitiesForRegularUsers: true,
|
||||
AllowDeviceMappingForRegularUsers: true,
|
||||
AllowStackManagementForRegularUsers: true,
|
||||
},
|
||||
}
|
||||
|
||||
err := snapshotService.SnapshotEndpoint(endpoint)
|
||||
if err != nil {
|
||||
logrus.Printf("http error: environment snapshot error (environment=%s, URL=%s) (err=%s)\n", endpoint.Name, endpoint.URL, err)
|
||||
}
|
||||
|
||||
return dataStore.Endpoint().Create(endpoint)
|
||||
}
|
||||
|
||||
func initEndpoint(flags *portainer.CLIFlags, dataStore dataservices.DataStore, snapshotService portainer.SnapshotService) error {
|
||||
if *flags.EndpointURL == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
endpoints, err := dataStore.Endpoint().Endpoints()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(endpoints) > 0 {
|
||||
logrus.Println("Instance already has defined environments. Skipping the environment defined via CLI.")
|
||||
return nil
|
||||
}
|
||||
|
||||
if *flags.TLS || *flags.TLSSkipVerify {
|
||||
return createTLSSecuredEndpoint(flags, dataStore, snapshotService)
|
||||
}
|
||||
return createUnsecuredEndpoint(*flags.EndpointURL, dataStore, snapshotService)
|
||||
}
|
||||
|
||||
func loadEncryptionSecretKey(keyfilename string) []byte {
    content, err := os.ReadFile(keyfilename)
    content, err := os.ReadFile(path.Join("/run/secrets", keyfilename))
    if err != nil {
        if os.IsNotExist(err) {
            log.Info().Str("filename", keyfilename).Msg("encryption key file not present")
            logrus.Printf("Encryption key file `%s` not present", keyfilename)
        } else {
            log.Info().Err(err).Msg("error reading encryption key file")
            logrus.Printf("Error reading encryption key file: %v", err)
        }

        return nil
    }

    // return a 32 byte hash of the secret (required for AES)
    // fips compliant version of this is not implemented in -ce
    hash := sha256.Sum256(content)

    return hash[:]
}
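To make the derivation step concrete, a minimal standalone sketch: whatever the length of the secret file, SHA-256 reduces it to the 32 bytes that AES-256 expects. The file path is an example only.

package main

import (
    "crypto/sha256"
    "fmt"
    "os"
)

func main() {
    // Any length of secret material hashes down to exactly 32 bytes,
    // which is the key size AES-256 requires.
    secret, err := os.ReadFile("/run/secrets/portainer") // example path
    if err != nil {
        fmt.Println("no encryption key, proceeding unencrypted")
        return
    }

    key := sha256.Sum256(secret)
    fmt.Printf("derived key length: %d bytes\n", len(key)) // always 32
}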
|
||||
func buildServer(flags *portainer.CLIFlags) portainer.Server {
|
||||
shutdownCtx, shutdownTrigger := context.WithCancel(context.Background())
|
||||
|
||||
if flags.FeatureFlags != nil {
|
||||
featureflags.Parse(*flags.FeatureFlags, portainer.SupportedFeatureFlags)
|
||||
}
|
||||
|
||||
trustedOrigins := []string{}
|
||||
if *flags.TrustedOrigins != "" {
|
||||
// validate if the trusted origins are valid urls
|
||||
for origin := range strings.SplitSeq(*flags.TrustedOrigins, ",") {
|
||||
if !validate.IsTrustedOrigin(origin) {
|
||||
log.Fatal().Str("trusted_origin", origin).Msg("invalid url for trusted origin. Please check the trusted origins flag.")
|
||||
}
|
||||
|
||||
trustedOrigins = append(trustedOrigins, origin)
|
||||
}
|
||||
}
|
||||
|
||||
// -ce can not ever be run in FIPS mode
|
||||
fips.InitFIPS(false)
|
||||
|
||||
fileService := initFileService(*flags.Data)
|
||||
encryptionKey := loadEncryptionSecretKey(dbSecretPath(*flags.SecretKeyName))
|
||||
encryptionKey := loadEncryptionSecretKey(*flags.SecretKeyName)
|
||||
if encryptionKey == nil {
|
||||
log.Info().Msg("proceeding without encryption key")
|
||||
logrus.Println("Proceeding without encryption key")
|
||||
}
|
||||
|
||||
dataStore := initDataStore(flags, encryptionKey, fileService, shutdownCtx)
|
||||
|
||||
if err := dataStore.CheckCurrentEdition(); err != nil {
|
||||
log.Fatal().Err(err).Msg("")
|
||||
logrus.Fatal(err)
|
||||
}
|
||||
|
||||
// check if the db schema version matches with server version
|
||||
if !checkDBSchemaServerVersionMatch(dataStore, portainer.APIVersion, int(portainer.Edition)) {
|
||||
log.Fatal().Msg("The database schema version does not align with the server version. Please consider reverting to the previous server version or addressing the database migration issue.")
|
||||
}
|
||||
|
||||
instanceID, err := dataStore.Version().InstanceID()
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("failed getting instance id")
|
||||
logrus.Fatalf("Failed getting instance id: %v", err)
|
||||
}
|
||||
|
||||
apiKeyService := initAPIKeyService(dataStore)
|
||||
|
||||
settings, err := dataStore.Settings().Settings()
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("")
|
||||
logrus.Fatal(err)
|
||||
}
|
||||
|
||||
jwtService, err := initJWTService(settings.UserSessionTimeout, dataStore)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("failed initializing JWT service")
|
||||
logrus.Fatalf("Failed initializing JWT service: %v", err)
|
||||
}
|
||||
|
||||
ldapService := ldap.Service{}
|
||||
|
||||
oauthService := oauth.NewService()
|
||||
|
||||
gitService := git.NewService(shutdownCtx)
|
||||
|
||||
// Setting insecureSkipVerify to true to preserve the old behaviour.
|
||||
openAMTService := openamt.NewService(true)
|
||||
|
||||
cryptoService := crypto.Service{}
|
||||
|
||||
signatureService := initDigitalSignatureService()
|
||||
|
||||
edgeStacksService := edgestacks.NewService(dataStore)
|
||||
|
||||
sslService, err := initSSLService(*flags.AddrHTTPS, *flags.TLSCert, *flags.TLSKey, fileService, dataStore, shutdownTrigger)
|
||||
err = enableFeaturesFromFlags(dataStore, flags)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("")
|
||||
logrus.Fatalf("Failed enabling feature flag: %v", err)
|
||||
}
|
||||
|
||||
ldapService := initLDAPService()
|
||||
oauthService := initOAuthService()
|
||||
gitService := initGitService()
|
||||
|
||||
openAMTService := openamt.NewService()
|
||||
|
||||
cryptoService := initCryptoService()
|
||||
|
||||
digitalSignatureService := initDigitalSignatureService()
|
||||
|
||||
sslService, err := initSSLService(*flags.AddrHTTPS, *flags.SSLCert, *flags.SSLKey, fileService, dataStore, shutdownTrigger)
|
||||
if err != nil {
|
||||
logrus.Fatal(err)
|
||||
}
|
||||
|
||||
sslSettings, err := sslService.GetSSLSettings()
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("failed to get SSL settings")
|
||||
logrus.Fatalf("Failed to get ssl settings: %s", err)
|
||||
}
|
||||
|
||||
if err := initKeyPair(fileService, signatureService); err != nil {
|
||||
log.Fatal().Err(err).Msg("failed initializing key pair")
|
||||
}
|
||||
|
||||
reverseTunnelService := chisel.NewService(dataStore, shutdownCtx, fileService)
|
||||
|
||||
dockerClientFactory := dockerclient.NewClientFactory(signatureService, reverseTunnelService)
|
||||
|
||||
kubernetesClientFactory, err := kubecli.NewClientFactory(signatureService, reverseTunnelService, dataStore, instanceID, *flags.AddrHTTPS, settings.UserSessionTimeout)
|
||||
err = initKeyPair(fileService, digitalSignatureService)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("failed initializing Kubernetes Client Factory service")
|
||||
logrus.Fatalf("Failed initializing key pair: %v", err)
|
||||
}
|
||||
|
||||
reverseTunnelService := chisel.NewService(dataStore, shutdownCtx)
|
||||
|
||||
dockerClientFactory := initDockerClientFactory(digitalSignatureService, reverseTunnelService)
|
||||
kubernetesClientFactory := initKubernetesClientFactory(digitalSignatureService, reverseTunnelService, instanceID, dataStore)
|
||||
|
||||
snapshotService, err := initSnapshotService(*flags.SnapshotInterval, dataStore, dockerClientFactory, kubernetesClientFactory, shutdownCtx)
|
||||
if err != nil {
|
||||
logrus.Fatalf("Failed initializing snapshot service: %v", err)
|
||||
}
|
||||
snapshotService.Start()
|
||||
|
||||
authorizationService := authorization.NewService(dataStore)
|
||||
authorizationService.K8sClientFactory = kubernetesClientFactory
|
||||
|
||||
@@ -439,60 +609,55 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
|
||||
|
||||
kubeClusterAccessService := kubernetes.NewKubeClusterAccessService(*flags.BaseURL, *flags.AddrHTTPS, sslSettings.CertPath)
|
||||
|
||||
proxyManager := proxy.NewManager(kubernetesClientFactory)
|
||||
proxyManager := proxy.NewManager(dataStore, digitalSignatureService, reverseTunnelService, dockerClientFactory, kubernetesClientFactory, kubernetesTokenCacheManager, gitService)
|
||||
|
||||
reverseTunnelService.ProxyManager = proxyManager
|
||||
|
||||
dockerConfigPath := fileService.GetDockerConfigPath()
|
||||
|
||||
composeDeployer := compose.NewComposeDeployer()
|
||||
composeStackManager := initComposeStackManager(*flags.Assets, dockerConfigPath, reverseTunnelService, proxyManager)
|
||||
|
||||
composeStackManager := exec.NewComposeStackManager(composeDeployer, proxyManager, dataStore)
|
||||
|
||||
swarmStackManager, err := exec.NewSwarmStackManager(*flags.Assets, dockerConfigPath, signatureService, fileService, reverseTunnelService, dataStore)
|
||||
swarmStackManager, err := initSwarmStackManager(*flags.Assets, dockerConfigPath, digitalSignatureService, fileService, reverseTunnelService, dataStore)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("failed initializing swarm stack manager")
|
||||
logrus.Fatalf("Failed initializing swarm stack manager: %v", err)
|
||||
}
|
||||
|
||||
kubernetesDeployer := initKubernetesDeployer(kubernetesTokenCacheManager, kubernetesClientFactory, dataStore, reverseTunnelService, signatureService, proxyManager)
|
||||
kubernetesDeployer := initKubernetesDeployer(kubernetesTokenCacheManager, kubernetesClientFactory, dataStore, reverseTunnelService, digitalSignatureService, proxyManager, *flags.Assets)
|
||||
|
||||
pendingActionsService := pendingactions.NewService(dataStore, kubernetesClientFactory)
|
||||
pendingActionsService.RegisterHandler(actions.CleanNAPWithOverridePolicies, handlers.NewHandlerCleanNAPWithOverridePolicies(authorizationService, dataStore))
|
||||
pendingActionsService.RegisterHandler(actions.DeletePortainerK8sRegistrySecrets, handlers.NewHandlerDeleteRegistrySecrets(authorizationService, dataStore, kubernetesClientFactory))
|
||||
pendingActionsService.RegisterHandler(actions.PostInitMigrateEnvironment, handlers.NewHandlerPostInitMigrateEnvironment(authorizationService, dataStore, kubernetesClientFactory, dockerClientFactory, *flags.Assets, kubernetesDeployer))
|
||||
|
||||
snapshotService, err := initSnapshotService(*flags.SnapshotInterval, dataStore, dockerClientFactory, kubernetesClientFactory, shutdownCtx, pendingActionsService)
|
||||
helmPackageManager, err := initHelmPackageManager(*flags.Assets)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("failed initializing snapshot service")
|
||||
logrus.Fatalf("Failed initializing helm package manager: %v", err)
|
||||
}
|
||||
|
||||
snapshotService.Start()
|
||||
|
||||
proxyManager.NewProxyFactory(dataStore, signatureService, reverseTunnelService, dockerClientFactory, kubernetesClientFactory, kubernetesTokenCacheManager, gitService, snapshotService, jwtService)
|
||||
|
||||
helmPackageManager, err := initHelmPackageManager()
|
||||
err = edge.LoadEdgeJobs(dataStore, reverseTunnelService)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("failed initializing helm package manager")
|
||||
logrus.Fatalf("Failed loading edge jobs from database: %v", err)
|
||||
}
|
||||
|
||||
applicationStatus := initStatus(instanceID)
|
||||
|
||||
// channel to control when the admin user is created
|
||||
adminCreationDone := make(chan struct{}, 1)
|
||||
demoService := demo.NewService()
|
||||
if *flags.DemoEnvironment {
|
||||
err := demoService.Init(dataStore, cryptoService)
|
||||
if err != nil {
|
||||
log.Fatalf("failed initializing demo environment: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
go endpointutils.InitEndpoint(shutdownCtx, adminCreationDone, flags, dataStore, snapshotService)
|
||||
err = initEndpoint(flags, dataStore, snapshotService)
|
||||
if err != nil {
|
||||
logrus.Fatalf("Failed initializing environment: %v", err)
|
||||
}
|
||||
|
||||
adminPasswordHash := ""
|
||||
|
||||
if *flags.AdminPasswordFile != "" {
|
||||
content, err := fileService.GetFileContent(*flags.AdminPasswordFile, "")
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("failed getting admin password file")
|
||||
logrus.Fatalf("Failed getting admin password file: %v", err)
|
||||
}
|
||||
|
||||
adminPasswordHash, err = cryptoService.Hash(strings.TrimSuffix(string(content), "\n"))
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("failed hashing admin password")
|
||||
logrus.Fatalf("Failed hashing admin password: %v", err)
|
||||
}
|
||||
} else if *flags.AdminPassword != "" {
|
||||
adminPasswordHash = *flags.AdminPassword
|
||||
@@ -501,76 +666,38 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
|
||||
if adminPasswordHash != "" {
|
||||
users, err := dataStore.User().UsersByRole(portainer.AdministratorRole)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("failed getting admin user")
|
||||
logrus.Fatalf("Failed getting admin user: %v", err)
|
||||
}
|
||||
|
||||
if len(users) == 0 {
|
||||
log.Info().Msg("created admin user with the given password.")
|
||||
|
||||
logrus.Println("Created admin user with the given password.")
|
||||
user := &portainer.User{
|
||||
Username: "admin",
|
||||
Role: portainer.AdministratorRole,
|
||||
Password: adminPasswordHash,
|
||||
}
|
||||
|
||||
if err := dataStore.User().Create(user); err != nil {
|
||||
log.Fatal().Err(err).Msg("failed creating admin user")
|
||||
err := dataStore.User().Create(user)
|
||||
if err != nil {
|
||||
logrus.Fatalf("Failed creating admin user: %v", err)
|
||||
}
|
||||
|
||||
// notify the admin user is created, the endpoint initialization can start
|
||||
adminCreationDone <- struct{}{}
|
||||
} else {
|
||||
log.Info().Msg("instance already has an administrator user defined, skipping admin password related flags.")
|
||||
logrus.Println("Instance already has an administrator user defined. Skipping admin password related flags.")
|
||||
}
|
||||
}
|
||||
|
||||
if err := reverseTunnelService.StartTunnelServer(*flags.TunnelAddr, *flags.TunnelPort, snapshotService); err != nil {
|
||||
log.Fatal().Err(err).Msg("failed starting tunnel server")
|
||||
}
|
||||
|
||||
scheduler := scheduler.NewScheduler(shutdownCtx)
|
||||
stackDeployer := deployments.NewStackDeployer(swarmStackManager, composeStackManager, kubernetesDeployer, dockerClientFactory, dataStore)
|
||||
if err := deployments.StartStackSchedules(scheduler, stackDeployer, dataStore, gitService); err != nil {
|
||||
log.Fatal().Err(err).Msg("failed to start stack scheduler")
|
||||
err = reverseTunnelService.StartTunnelServer(*flags.TunnelAddr, *flags.TunnelPort, snapshotService)
|
||||
if err != nil {
|
||||
logrus.Fatalf("Failed starting tunnel server: %v", err)
|
||||
}
|
||||
|
||||
sslDBSettings, err := dataStore.SSLSettings().Settings()
|
||||
if err != nil {
|
||||
log.Fatal().Msg("failed to fetch SSL settings from DB")
|
||||
logrus.Fatalf("Failed to fetch ssl settings from DB")
|
||||
}
|
||||
|
||||
platformService, err := platform.NewService(dataStore)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("failed initializing platform service")
|
||||
}
|
||||
|
||||
upgradeService, err := upgrade.NewService(
|
||||
*flags.Assets,
|
||||
kubernetesClientFactory,
|
||||
dockerClientFactory,
|
||||
composeStackManager,
|
||||
dataStore,
|
||||
fileService,
|
||||
stackDeployer,
|
||||
)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("failed initializing upgrade service")
|
||||
}
|
||||
|
||||
// Our normal migrations run as part of the database initialization
|
||||
// but some more complex migrations require access to a kubernetes or docker
|
||||
// client. Therefore we run a separate migration process just before
|
||||
// starting the server.
|
||||
postInitMigrator := postinit.NewPostInitMigrator(
|
||||
kubernetesClientFactory,
|
||||
dockerClientFactory,
|
||||
dataStore,
|
||||
*flags.Assets,
|
||||
kubernetesDeployer,
|
||||
)
|
||||
if err := postInitMigrator.PostInitMigrate(); err != nil {
|
||||
log.Fatal().Err(err).Msg("failure during post init migrations")
|
||||
}
|
||||
scheduler := scheduler.NewScheduler(shutdownCtx)
|
||||
stackDeployer := stacks.NewStackDeployer(swarmStackManager, composeStackManager, kubernetesDeployer)
|
||||
stacks.StartStackSchedules(scheduler, stackDeployer, dataStore, gitService)
|
||||
|
||||
return &http.Server{
|
||||
AuthorizationService: authorizationService,
|
||||
@@ -578,17 +705,15 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
|
||||
Status: applicationStatus,
|
||||
BindAddress: *flags.Addr,
|
||||
BindAddressHTTPS: *flags.AddrHTTPS,
|
||||
CSP: *flags.CSP,
|
||||
HTTPEnabled: sslDBSettings.HTTPEnabled,
|
||||
AssetsPath: *flags.Assets,
|
||||
DataStore: dataStore,
|
||||
EdgeStacksService: edgeStacksService,
|
||||
SwarmStackManager: swarmStackManager,
|
||||
ComposeStackManager: composeStackManager,
|
||||
KubernetesDeployer: kubernetesDeployer,
|
||||
HelmPackageManager: helmPackageManager,
|
||||
APIKeyService: apiKeyService,
|
||||
CryptoService: cryptoService,
|
||||
APIKeyService: apiKeyService,
|
||||
JWTService: jwtService,
|
||||
FileService: fileService,
|
||||
LDAPService: ldapService,
|
||||
@@ -598,7 +723,7 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
|
||||
ProxyManager: proxyManager,
|
||||
KubernetesTokenCacheManager: kubernetesTokenCacheManager,
|
||||
KubeClusterAccessService: kubeClusterAccessService,
|
||||
SignatureService: signatureService,
|
||||
SignatureService: digitalSignatureService,
|
||||
SnapshotService: snapshotService,
|
||||
SSLService: sslService,
|
||||
DockerClientFactory: dockerClientFactory,
|
||||
@@ -607,39 +732,19 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
|
||||
ShutdownCtx: shutdownCtx,
|
||||
ShutdownTrigger: shutdownTrigger,
|
||||
StackDeployer: stackDeployer,
|
||||
UpgradeService: upgradeService,
|
||||
AdminCreationDone: adminCreationDone,
|
||||
PendingActionsService: pendingActionsService,
|
||||
PlatformService: platformService,
|
||||
PullLimitCheckDisabled: *flags.PullLimitCheckDisabled,
|
||||
TrustedOrigins: trustedOrigins,
|
||||
DemoService: demoService,
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
logs.ConfigureLogger()
|
||||
logs.SetLoggingMode("PRETTY")
|
||||
|
||||
flags := initCLI()
|
||||
|
||||
logs.SetLoggingLevel(*flags.LogLevel)
|
||||
logs.SetLoggingMode(*flags.LogMode)
|
||||
configureLogger()
|
||||
|
||||
for {
|
||||
server := buildServer(flags)
|
||||
|
||||
log.Info().
|
||||
Str("version", portainer.APIVersion).
|
||||
Str("build_number", build.BuildNumber).
|
||||
Str("image_tag", build.ImageTag).
|
||||
Str("nodejs_version", build.NodejsVersion).
|
||||
Str("pnpm_version", build.PnpmVersion).
|
||||
Str("webpack_version", build.WebpackVersion).
|
||||
Str("go_version", build.GoVersion).
|
||||
Msg("starting Portainer")
|
||||
|
||||
logrus.Printf("[INFO] [cmd,main] Starting Portainer version %s\n", portainer.APIVersion)
|
||||
err := server.Start()
|
||||
|
||||
log.Info().Err(err).Msg("HTTP server exited")
|
||||
logrus.Printf("[INFO] [cmd,main] Http server exited: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,57 +1,110 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/cli"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/alecthomas/kingpin.v2"
|
||||
)
|
||||
|
||||
const secretFileName = "secret.txt"
|
||||
type mockKingpinSetting string
|
||||
|
||||
func createPasswordFile(t *testing.T, secretPath, password string) string {
|
||||
err := os.WriteFile(secretPath, []byte(password), 0600)
|
||||
require.NoError(t, err)
|
||||
return secretPath
|
||||
func (m mockKingpinSetting) SetValue(value kingpin.Value) {
|
||||
value.Set(string(m))
|
||||
}
|
||||
|
||||
func TestLoadEncryptionSecretKey(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
secretPath := path.Join(tempDir, secretFileName)
|
||||
func Test_enableFeaturesFromFlags(t *testing.T) {
|
||||
is := assert.New(t)
|
||||
|
||||
// first pointing to file that does not exist, gives nil hash (no encryption)
|
||||
encryptionKey := loadEncryptionSecretKey(secretPath)
|
||||
require.Nil(t, encryptionKey)
|
||||
_, store, teardown := datastore.MustNewTestStore(true, true)
|
||||
defer teardown()
|
||||
|
||||
// point to a directory instead of a file
|
||||
encryptionKey = loadEncryptionSecretKey(tempDir)
|
||||
require.Nil(t, encryptionKey)
|
||||
|
||||
password := "portainer@1234"
|
||||
createPasswordFile(t, secretPath, password)
|
||||
|
||||
encryptionKey = loadEncryptionSecretKey(secretPath)
|
||||
require.NotNil(t, encryptionKey)
|
||||
// should be 32 bytes for aes256 encryption
|
||||
require.Len(t, encryptionKey, 32)
|
||||
}
|
||||
|
||||
func TestDBSecretPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
keyFilenameFlag string
|
||||
expected string
|
||||
featureFlag string
|
||||
isSupported bool
|
||||
}{
|
||||
{keyFilenameFlag: "secret.txt", expected: "/run/secrets/secret.txt"},
|
||||
{keyFilenameFlag: "/tmp/secret.txt", expected: "/tmp/secret.txt"},
|
||||
{keyFilenameFlag: "/run/secrets/secret.txt", expected: "/run/secrets/secret.txt"},
|
||||
{keyFilenameFlag: "./secret.txt", expected: "/run/secrets/secret.txt"},
|
||||
{keyFilenameFlag: "../secret.txt", expected: "/run/secret.txt"},
|
||||
{keyFilenameFlag: "foo/bar/secret.txt", expected: "/run/secrets/foo/bar/secret.txt"},
|
||||
{"test", false},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(fmt.Sprintf("%s succeeds:%v", test.featureFlag, test.isSupported), func(t *testing.T) {
|
||||
mockKingpinSetting := mockKingpinSetting(test.featureFlag)
|
||||
flags := &portainer.CLIFlags{FeatureFlags: cli.BoolPairs(mockKingpinSetting)}
|
||||
err := enableFeaturesFromFlags(store, flags)
|
||||
if test.isSupported {
|
||||
is.NoError(err)
|
||||
} else {
|
||||
is.Error(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
assert.Equal(t, test.expected, dbSecretPath(test.keyFilenameFlag))
|
||||
}
|
||||
t.Run("passes for all supported feature flags", func(t *testing.T) {
|
||||
for _, flag := range portainer.SupportedFeatureFlags {
|
||||
mockKingpinSetting := mockKingpinSetting(flag)
|
||||
flags := &portainer.CLIFlags{FeatureFlags: cli.BoolPairs(mockKingpinSetting)}
|
||||
err := enableFeaturesFromFlags(store, flags)
|
||||
is.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
const FeatTest portainer.Feature = "optional-test"
|
||||
|
||||
func optionalFunc(dataStore dataservices.DataStore) string {
|
||||
|
||||
// TODO: this is a code smell - finding out if a feature flag is enabled should not require having access to the store, and asking for a settings obj.
|
||||
// ideally, the `if` should look more like:
|
||||
// if featureflags.FlagEnabled(FeatTest) {}
|
||||
settings, err := dataStore.Settings().Settings()
|
||||
if err != nil {
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
if settings.FeatureFlagSettings[FeatTest] {
|
||||
return "enabled"
|
||||
}
|
||||
return "disabled"
|
||||
}
|
||||
|
||||
func Test_optionalFeature(t *testing.T) {
|
||||
portainer.SupportedFeatureFlags = append(portainer.SupportedFeatureFlags, FeatTest)
|
||||
|
||||
is := assert.New(t)
|
||||
|
||||
_, store, teardown := datastore.MustNewTestStore(true, true)
|
||||
defer teardown()
|
||||
|
||||
// Enable the test feature
|
||||
t.Run(fmt.Sprintf("%s succeeds:%v", FeatTest, true), func(t *testing.T) {
|
||||
mockKingpinSetting := mockKingpinSetting(FeatTest)
|
||||
flags := &portainer.CLIFlags{FeatureFlags: cli.BoolPairs(mockKingpinSetting)}
|
||||
err := enableFeaturesFromFlags(store, flags)
|
||||
is.NoError(err)
|
||||
is.Equal("enabled", optionalFunc(store))
|
||||
})
|
||||
|
||||
// Same store, so the feature flag should still be enabled
|
||||
t.Run(fmt.Sprintf("%s succeeds:%v", FeatTest, true), func(t *testing.T) {
|
||||
is.Equal("enabled", optionalFunc(store))
|
||||
})
|
||||
|
||||
// disable the test feature
|
||||
t.Run(fmt.Sprintf("%s succeeds:%v", FeatTest, true), func(t *testing.T) {
|
||||
mockKingpinSetting := mockKingpinSetting(FeatTest + "=false")
|
||||
flags := &portainer.CLIFlags{FeatureFlags: cli.BoolPairs(mockKingpinSetting)}
|
||||
err := enableFeaturesFromFlags(store, flags)
|
||||
is.NoError(err)
|
||||
is.Equal("disabled", optionalFunc(store))
|
||||
})
|
||||
|
||||
// Same store, so feature flag should still be disabled
|
||||
t.Run(fmt.Sprintf("%s succeeds:%v", FeatTest, true), func(t *testing.T) {
|
||||
is.Equal("disabled", optionalFunc(store))
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
@@ -1,148 +0,0 @@
|
||||
// Package concurrent provides utilities for running multiple functions concurrently in Go.
|
||||
// For example, many kubernetes calls can take a while to fulfill. Oftentimes in Portainer
|
||||
// we need to get a list of objects from multiple kubernetes REST APIs. We can often call these
|
||||
// apis concurrently to speed up the response time.
|
||||
// This package provides a clean way to do just that.
|
||||
//
|
||||
// Examples:
|
||||
// The ConfigMaps and Secrets function converted using concurrent.Run.
|
||||
/*
|
||||
|
||||
// GetConfigMapsAndSecrets gets all the ConfigMaps AND all the Secrets for a
|
||||
// given namespace in a k8s endpoint. The result is a list of both config maps
|
||||
// and secrets. The IsSecret boolean property indicates if a given struct is a
|
||||
// secret or configmap.
|
||||
func (kcl *KubeClient) GetConfigMapsAndSecrets(namespace string) ([]models.K8sConfigMapOrSecret, error) {
|
||||
|
||||
// use closures to capture the current kube client and namespace by declaring wrapper functions
|
||||
// that match the interface signature for concurrent.Func
|
||||
|
||||
listConfigMaps := func(ctx context.Context) (any, error) {
|
||||
return kcl.cli.CoreV1().ConfigMaps(namespace).List(context.Background(), meta.ListOptions{})
|
||||
}
|
||||
|
||||
listSecrets := func(ctx context.Context) (any, error) {
|
||||
return kcl.cli.CoreV1().Secrets(namespace).List(context.Background(), meta.ListOptions{})
|
||||
}
|
||||
|
||||
// run the functions concurrently and wait for results. We can also pass in a context to cancel.
|
||||
// e.g. Deadline timer.
|
||||
results, err := concurrent.Run(context.TODO(), listConfigMaps, listSecrets)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var configMapList *core.ConfigMapList
|
||||
var secretList *core.SecretList
|
||||
for _, r := range results {
|
||||
switch v := r.Result.(type) {
|
||||
case *core.ConfigMapList:
|
||||
configMapList = v
|
||||
case *core.SecretList:
|
||||
secretList = v
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Applications
|
||||
var combined []models.K8sConfigMapOrSecret
|
||||
for _, m := range configMapList.Items {
|
||||
var cm models.K8sConfigMapOrSecret
|
||||
cm.UID = string(m.UID)
|
||||
cm.Name = m.Name
|
||||
cm.Namespace = m.Namespace
|
||||
cm.Annotations = m.Annotations
|
||||
cm.Data = m.Data
|
||||
cm.CreationDate = m.CreationTimestamp.Time.UTC().Format(time.RFC3339)
|
||||
combined = append(combined, cm)
|
||||
}
|
||||
|
||||
for _, s := range secretList.Items {
|
||||
var secret models.K8sConfigMapOrSecret
|
||||
secret.UID = string(s.UID)
|
||||
secret.Name = s.Name
|
||||
secret.Namespace = s.Namespace
|
||||
secret.Annotations = s.Annotations
|
||||
secret.Data = msbToMss(s.Data)
|
||||
secret.CreationDate = s.CreationTimestamp.Time.UTC().Format(time.RFC3339)
|
||||
secret.IsSecret = true
|
||||
secret.SecretType = string(s.Type)
|
||||
combined = append(combined, secret)
|
||||
}
|
||||
|
||||
return combined, nil
|
||||
}
|
||||
|
||||
*/
|
||||
|
||||
package concurrent

import (
	"context"
	"sync"
)

// Result contains the result and any error returned from running a client task function
type Result struct {
	Result any   // the result of running the task function
	Err    error // any error that occurred while running the task function
}

// Func is a function that returns a result or an error
type Func func(ctx context.Context) (any, error)

// Run runs a list of functions concurrently and returns their results
func Run(ctx context.Context, maxConcurrency int, tasks ...Func) ([]Result, error) {
	var wg sync.WaitGroup

	resultsChan := make(chan Result, len(tasks))
	taskChan := make(chan Func, len(tasks))

	localCtx, cancelCtx := context.WithCancel(ctx)
	defer cancelCtx()

	runTask := func() {
		defer wg.Done()

		for fn := range taskChan {
			result, err := fn(localCtx)
			resultsChan <- Result{Result: result, Err: err}
		}
	}

	// Set maxConcurrency to the number of tasks if zero or negative
	if maxConcurrency <= 0 {
		maxConcurrency = len(tasks)
	}

	// Start worker goroutines
	for range maxConcurrency {
		wg.Add(1)
		go runTask()
	}

	// Add tasks to the task channel
	for _, fn := range tasks {
		taskChan <- fn
	}

	// Close the task channel to signal workers to stop when all tasks are done
	close(taskChan)

	// Wait for all workers to complete
	wg.Wait()
	close(resultsChan)

	// Collect the results and cancel on error
	results := make([]Result, 0, len(tasks))
	for r := range resultsChan {
		if r.Err != nil {
			cancelCtx()

			return nil, r.Err
		}

		results = append(results, r)
	}

	return results, nil
}
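A minimal usage sketch of Run outside of Kubernetes code (the import path and the task bodies are assumptions for illustration, not taken from this changeset):

package main

import (
	"context"
	"fmt"

	"github.com/portainer/portainer/pkg/concurrent" // assumed import path
)

func main() {
	answer := func(ctx context.Context) (any, error) { return 42, nil }
	greeting := func(ctx context.Context) (any, error) { return "hello", nil }

	// A maxConcurrency of 0 starts one worker per task.
	results, err := concurrent.Run(context.Background(), 0, answer, greeting)
	if err != nil {
		// Run cancels the shared context and returns the first error it sees.
		fmt.Println("error:", err)
		return
	}

	// Results arrive in completion order, not submission order.
	for _, r := range results {
		fmt.Println(r.Result)
	}
}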
@@ -4,36 +4,10 @@ import (
|
||||
"io"
|
||||
)
|
||||
|
||||
type ReadTransaction interface {
|
||||
GetObject(bucketName string, key []byte, object any) error
|
||||
GetRawBytes(bucketName string, key []byte) ([]byte, error)
|
||||
GetAll(bucketName string, obj any, append func(o any) (any, error)) error
|
||||
GetAllWithKeyPrefix(bucketName string, keyPrefix []byte, obj any, append func(o any) (any, error)) error
|
||||
KeyExists(bucketName string, key []byte) (bool, error)
|
||||
}
|
||||
|
||||
type Transaction interface {
|
||||
ReadTransaction
|
||||
|
||||
SetServiceName(bucketName string) error
|
||||
UpdateObject(bucketName string, key []byte, object any) error
|
||||
DeleteObject(bucketName string, key []byte) error
|
||||
CreateObject(bucketName string, fn func(uint64) (int, any)) error
|
||||
CreateObjectWithId(bucketName string, id int, obj any) error
|
||||
CreateObjectWithStringId(bucketName string, id []byte, obj any) error
|
||||
DeleteAllObjects(bucketName string, obj any, matching func(o any) (id int, ok bool)) error
|
||||
GetNextIdentifier(bucketName string) int
|
||||
}
|
||||
|
||||
type Connection interface {
|
||||
Transaction
|
||||
|
||||
Open() error
|
||||
Close() error
|
||||
|
||||
UpdateTx(fn func(Transaction) error) error
|
||||
ViewTx(fn func(Transaction) error) error
|
||||
|
||||
// write the db contents to filename as json (the schema needs defining)
|
||||
ExportRaw(filename string) error
|
||||
|
||||
@@ -42,15 +16,25 @@ type Connection interface {
|
||||
GetDatabaseFileName() string
|
||||
GetDatabaseFilePath() string
|
||||
GetStorePath() string
|
||||
GetDatabaseFileSize() (int64, error)
|
||||
|
||||
IsEncryptedStore() bool
|
||||
NeedsEncryptionMigration() (bool, error)
|
||||
SetEncrypted(encrypted bool)
|
||||
|
||||
BackupMetadata() (map[string]any, error)
|
||||
RestoreMetadata(s map[string]any) error
|
||||
|
||||
UpdateObjectFunc(bucketName string, key []byte, object any, updateFn func()) error
|
||||
SetServiceName(bucketName string) error
|
||||
GetObject(bucketName string, key []byte, object interface{}) error
|
||||
UpdateObject(bucketName string, key []byte, object interface{}) error
|
||||
DeleteObject(bucketName string, key []byte) error
|
||||
DeleteAllObjects(bucketName string, matching func(o interface{}) (id int, ok bool)) error
|
||||
GetNextIdentifier(bucketName string) int
|
||||
CreateObject(bucketName string, fn func(uint64) (int, interface{})) error
|
||||
CreateObjectWithId(bucketName string, id int, obj interface{}) error
|
||||
CreateObjectWithStringId(bucketName string, id []byte, obj interface{}) error
|
||||
CreateObjectWithSetSequence(bucketName string, id int, obj interface{}) error
|
||||
GetAll(bucketName string, obj interface{}, append func(o interface{}) (interface{}, error)) error
|
||||
GetAllWithJsoniter(bucketName string, obj interface{}, append func(o interface{}) (interface{}, error)) error
|
||||
ConvertToKey(v int) []byte
|
||||
|
||||
BackupMetadata() (map[string]interface{}, error)
|
||||
RestoreMetadata(s map[string]interface{}) error
|
||||
}
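ViewTx runs its callback inside a read-only transaction and UpdateTx inside a read-write one. A hedged sketch of a caller (the "settings" bucket, key and struct are illustrative, not part of this changeset):

// bumpSnapshotInterval shows the read-modify-write pattern inside one UpdateTx.
func bumpSnapshotInterval(conn Connection, key []byte) error {
	type settings struct {
		SnapshotInterval string
	}

	return conn.UpdateTx(func(tx Transaction) error {
		var s settings
		if err := tx.GetObject("settings", key, &s); err != nil {
			return err
		}

		s.SnapshotInterval = "5m"

		return tx.UpdateObject("settings", key, &s)
	})
}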
@@ -1,382 +1,55 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/pbkdf2"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
|
||||
// Not allowed in FIPS mode
|
||||
"golang.org/x/crypto/argon2" //nolint:depguard
|
||||
"golang.org/x/crypto/scrypt" //nolint:depguard
|
||||
"golang.org/x/crypto/scrypt"
|
||||
)
|
||||
|
||||
const (
|
||||
// AES GCM settings
|
||||
aesGcmHeader = "AES256-GCM" // The encrypted file header
|
||||
aesGcmBlockSize = 1024 * 1024 // 1MB block for aes gcm
|
||||
// NOTE: the legacy scheme below is simplistic in that it omits any
// authentication of the encrypted data.
// A person with better knowledge is welcome to improve it.
|
||||
// sourced from https://golang.org/src/crypto/cipher/example_test.go
|
||||
|
||||
aesGcmFIPSHeader = "FIPS-AES256-GCM"
|
||||
aesGcmFIPSBlockSize = 16 * 1024 * 1024 // 16MB block for aes gcm
|
||||
var emptySalt []byte = make([]byte, 0, 0)
|
||||
|
||||
// Argon2 settings
|
||||
// Recommended settings for lower-memory hardware according to current OWASP recommendations
|
||||
// Considering some people run portainer on a NAS I think it's prudent not to assume we're on server grade hardware
|
||||
// https://cheatsheetseries.owasp.org/cheatsheets/Password_Storage_Cheat_Sheet.html#argon2id
|
||||
argon2MemoryCost = 12 * 1024
|
||||
argon2TimeCost = 3
|
||||
argon2Threads = 1
|
||||
argon2KeyLength = 32
|
||||
|
||||
pbkdf2Iterations = 600_000 // use recommended iterations from https://cheatsheetseries.owasp.org/cheatsheets/Password_Storage_Cheat_Sheet.html#pbkdf2 a little overkill for this use
|
||||
pbkdf2SaltLength = 32
|
||||
)
|
||||
|
||||
// AesEncrypt reads from input, encrypts with AES-256 and writes to output. passphrase is used to generate an encryption key
|
||||
func AesEncrypt(input io.Reader, output io.Writer, passphrase []byte) error {
|
||||
if fips.FIPSMode() {
|
||||
if err := aesEncryptGCMFIPS(input, output, passphrase); err != nil {
|
||||
return fmt.Errorf("error encrypting file: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := aesEncryptGCM(input, output, passphrase); err != nil {
|
||||
return fmt.Errorf("error encrypting file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AesDecrypt reads from input, decrypts with AES-256 and returns the reader to read the decrypted content from
|
||||
func AesDecrypt(input io.Reader, passphrase []byte) (io.Reader, error) {
|
||||
return aesDecrypt(input, passphrase, fips.FIPSMode())
|
||||
}
|
||||
|
||||
func aesDecrypt(input io.Reader, passphrase []byte, fipsMode bool) (io.Reader, error) {
|
||||
// Read file header to determine how it was encrypted
|
||||
inputReader := bufio.NewReader(input)
|
||||
header, err := inputReader.Peek(len(aesGcmFIPSHeader))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading encrypted backup file header: %w", err)
|
||||
}
|
||||
|
||||
if strings.HasPrefix(string(header), aesGcmFIPSHeader) {
|
||||
if !fipsMode {
|
||||
return nil, errors.New("fips encrypted file detected but fips mode is not enabled")
|
||||
}
|
||||
|
||||
reader, err := aesDecryptGCMFIPS(inputReader, passphrase)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error decrypting file: %w", err)
|
||||
}
|
||||
|
||||
return reader, nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(string(header), aesGcmHeader) {
|
||||
if fipsMode {
|
||||
return nil, errors.New("fips mode is enabled but non-fips encrypted file detected")
|
||||
}
|
||||
|
||||
reader, err := aesDecryptGCM(inputReader, passphrase)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error decrypting file: %w", err)
|
||||
}
|
||||
|
||||
return reader, nil
|
||||
}
|
||||
|
||||
// Use the previous decryption routine which has no header (to support older archives)
|
||||
reader, err := aesDecryptOFB(inputReader, passphrase)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error decrypting legacy file backup: %w", err)
|
||||
}
|
||||
|
||||
return reader, nil
|
||||
}
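Because aesDecrypt sniffs the header, callers can treat all archive formats uniformly through AesDecrypt. A minimal sketch of decrypting a backup file to disk (file paths are illustrative; assumes the io and os imports):

func decryptBackupFile(src, dst string, passphrase []byte) error {
	in, err := os.Open(src)
	if err != nil {
		return err
	}
	defer in.Close()

	out, err := os.Create(dst)
	if err != nil {
		return err
	}
	defer out.Close()

	// AesDecrypt picks GCM, FIPS GCM or the legacy OFB path from the header.
	plain, err := AesDecrypt(in, passphrase)
	if err != nil {
		return err
	}

	_, err = io.Copy(out, plain)

	return err
}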
|
||||
// aesEncryptGCM reads from input, encrypts with AES-256 and writes to output. passphrase is used to generate an encryption key.
|
||||
func aesEncryptGCM(input io.Reader, output io.Writer, passphrase []byte) error {
|
||||
// Derive key using argon2 with a random salt
|
||||
salt := make([]byte, 16) // 16 bytes salt
|
||||
if _, err := io.ReadFull(rand.Reader, salt); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
key := argon2.IDKey(passphrase, salt, argon2TimeCost, argon2MemoryCost, argon2Threads, 32)
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
aesgcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Generate nonce
|
||||
nonce, err := NewRandomNonce(aesgcm.NonceSize())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// write the header
|
||||
if _, err := output.Write([]byte(aesGcmHeader)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write nonce and salt to the output file
|
||||
if _, err := output.Write(salt); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := output.Write(nonce.Value()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Buffer for reading plaintext blocks
|
||||
buf := make([]byte, aesGcmBlockSize) // Adjust buffer size as needed
|
||||
ciphertext := make([]byte, len(buf)+aesgcm.Overhead())
|
||||
|
||||
// Encrypt plaintext in blocks
|
||||
for {
|
||||
n, err := io.ReadFull(input, buf)
|
||||
if n == 0 {
|
||||
break // end of plaintext input
|
||||
}
|
||||
|
||||
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return err
|
||||
}
|
||||
|
||||
// Seal encrypts the plaintext using the nonce returning the updated slice.
|
||||
ciphertext = aesgcm.Seal(ciphertext[:0], nonce.Value(), buf[:n], nil)
|
||||
|
||||
if _, err := output.Write(ciphertext); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := nonce.Increment(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// aesDecryptGCM reads from input, decrypts with AES-256 and returns the reader to read the decrypted content from.
|
||||
func aesDecryptGCM(input io.Reader, passphrase []byte) (io.Reader, error) {
|
||||
// Read & verify the header
|
||||
header := make([]byte, len(aesGcmHeader))
|
||||
if _, err := io.ReadFull(input, header); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if string(header) != aesGcmHeader {
|
||||
return nil, errors.New("invalid header")
|
||||
}
|
||||
|
||||
// Read salt
|
||||
salt := make([]byte, 16) // Salt size
|
||||
if _, err := io.ReadFull(input, salt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := argon2.IDKey(passphrase, salt, argon2TimeCost, argon2MemoryCost, argon2Threads, 32)
|
||||
|
||||
// Initialize AES cipher block
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create GCM mode with the cipher block
|
||||
aesgcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Read nonce from the input reader
|
||||
nonce := NewNonce(aesgcm.NonceSize())
|
||||
if err := nonce.Read(input); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Initialize a buffer to store decrypted data
|
||||
buf := bytes.Buffer{}
|
||||
plaintext := make([]byte, aesGcmBlockSize)
|
||||
|
||||
// Decrypt the ciphertext in blocks
|
||||
for {
|
||||
// Read a block of ciphertext from the input reader
|
||||
ciphertextBlock := make([]byte, aesGcmBlockSize+aesgcm.Overhead()) // Adjust block size as needed
|
||||
n, err := io.ReadFull(input, ciphertextBlock)
|
||||
if n == 0 {
|
||||
break // end of ciphertext
|
||||
}
|
||||
|
||||
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Decrypt the block of ciphertext
|
||||
plaintext, err = aesgcm.Open(plaintext[:0], nonce.Value(), ciphertextBlock[:n], nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := buf.Write(plaintext); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := nonce.Increment(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &buf, nil
|
||||
}
|
||||
|
||||
// aesEncryptGCMFIPS reads from input, encrypts with AES-256 in a fips compliant
|
||||
// way and writes to output. passphrase is used to generate an encryption key.
|
||||
func aesEncryptGCMFIPS(input io.Reader, output io.Writer, passphrase []byte) error {
|
||||
salt := make([]byte, pbkdf2SaltLength)
|
||||
if _, err := io.ReadFull(rand.Reader, salt); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
key, err := pbkdf2.Key(sha256.New, string(passphrase), salt, pbkdf2Iterations, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error deriving key: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// write the header
|
||||
if _, err := output.Write([]byte(aesGcmFIPSHeader)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write nonce and salt to the output file
|
||||
if _, err := output.Write(salt); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Buffer for reading plaintext blocks
|
||||
buf := make([]byte, aesGcmFIPSBlockSize)
|
||||
|
||||
// Encrypt plaintext in blocks
|
||||
for {
|
||||
// new random nonce for each block
|
||||
aesgcm, err := cipher.NewGCMWithRandomNonce(block)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating gcm: %w", err)
|
||||
}
|
||||
|
||||
n, err := io.ReadFull(input, buf)
|
||||
if n == 0 {
|
||||
break // end of plaintext input
|
||||
}
|
||||
|
||||
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return err
|
||||
}
|
||||
|
||||
// Seal encrypts the plaintext
|
||||
ciphertext := aesgcm.Seal(nil, nil, buf[:n], nil)
|
||||
|
||||
if _, err := output.Write(ciphertext); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// aesDecryptGCMFIPS reads from input, decrypts with AES-256 in a fips compliant
|
||||
// way and returns the reader to read the decrypted content from.
|
||||
func aesDecryptGCMFIPS(input io.Reader, passphrase []byte) (io.Reader, error) {
|
||||
// Read & verify the header
|
||||
header := make([]byte, len(aesGcmFIPSHeader))
|
||||
if _, err := io.ReadFull(input, header); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if string(header) != aesGcmFIPSHeader {
|
||||
return nil, errors.New("invalid header")
|
||||
}
|
||||
|
||||
// Read salt
|
||||
salt := make([]byte, pbkdf2SaltLength)
|
||||
if _, err := io.ReadFull(input, salt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key, err := pbkdf2.Key(sha256.New, string(passphrase), salt, pbkdf2Iterations, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error deriving key: %w", err)
|
||||
}
|
||||
|
||||
// Initialize AES cipher block
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Initialize a buffer to store decrypted data
|
||||
buf := bytes.Buffer{}
|
||||
|
||||
// Decrypt the ciphertext in blocks
|
||||
for {
|
||||
// Create GCM mode with the cipher block
|
||||
aesgcm, err := cipher.NewGCMWithRandomNonce(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Read a block of ciphertext from the input reader
|
||||
ciphertextBlock := make([]byte, aesGcmFIPSBlockSize+aesgcm.Overhead())
|
||||
n, err := io.ReadFull(input, ciphertextBlock)
|
||||
if n == 0 {
|
||||
break // end of ciphertext
|
||||
}
|
||||
|
||||
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Decrypt the block of ciphertext
|
||||
plaintext, err := aesgcm.Open(nil, nil, ciphertextBlock[:n], nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := buf.Write(plaintext); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &buf, nil
|
||||
}
|
||||
|
||||
// aesDecryptOFB reads from input, decrypts with AES-256 and returns a reader to read the decrypted content from.
|
||||
// AesEncrypt reads from input, encrypts with AES-256 and writes to the output.
|
||||
// passphrase is used to generate an encryption key.
|
||||
// note: This function is used to decrypt files that were encrypted without a header, i.e. old archives
|
||||
func aesDecryptOFB(input io.Reader, passphrase []byte) (io.Reader, error) {
|
||||
func AesEncrypt(input io.Reader, output io.Writer, passphrase []byte) error {
|
||||
// making a 32-byte key that corresponds to AES-256
// a salt isn't strictly needed here, so it is left empty
|
||||
key, err := scrypt.Key(passphrase, nil, 32768, 8, 1, 32)
|
||||
key, err := scrypt.Key(passphrase, emptySalt, 32768, 8, 1, 32)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If the key is unique for each ciphertext, then it's ok to use a zero
|
||||
// IV.
|
||||
var iv [aes.BlockSize]byte
|
||||
stream := cipher.NewOFB(block, iv[:])
|
||||
|
||||
writer := &cipher.StreamWriter{S: stream, W: output}
|
||||
// Copy the input to the output, encrypting as we go.
|
||||
if _, err := io.Copy(writer, input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AesDecrypt reads from input, decrypts with AES-256 and returns a reader to read the decrypted content from.
|
||||
// passphrase is used to generate an encryption key.
|
||||
func AesDecrypt(input io.Reader, passphrase []byte) (io.Reader, error) {
|
||||
// making a 32-byte key that corresponds to AES-256
// a salt isn't strictly needed here, so it is left empty
|
||||
key, err := scrypt.Key(passphrase, emptySalt, 32768, 8, 1, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -386,25 +59,12 @@ func aesDecryptOFB(input io.Reader, passphrase []byte) (io.Reader, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If the key is unique for each ciphertext, then it's ok to use a zero IV.
|
||||
// If the key is unique for each ciphertext, then it's ok to use a zero
|
||||
// IV.
|
||||
var iv [aes.BlockSize]byte
|
||||
stream := cipher.NewOFB(block, iv[:])
|
||||
|
||||
reader := &cipher.StreamReader{S: stream, R: input}
|
||||
|
||||
return reader, nil
|
||||
}
|
||||
|
||||
// HasEncryptedHeader checks if the data has an encrypted header, note that fips
|
||||
// mode changes this behavior and so will only recognize data encrypted by the
|
||||
// same mode (fips enabled or disabled)
|
||||
func HasEncryptedHeader(data []byte) bool {
|
||||
return hasEncryptedHeader(data, fips.FIPSMode())
|
||||
}
|
||||
|
||||
func hasEncryptedHeader(data []byte, fipsMode bool) bool {
|
||||
if fipsMode {
|
||||
return bytes.HasPrefix(data, []byte(aesGcmFIPSHeader))
|
||||
}
|
||||
|
||||
return bytes.HasPrefix(data, []byte(aesGcmHeader))
|
||||
}
|
||||
|
||||
@@ -1,444 +1,132 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"io"
|
||||
"math/rand"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/portainer/portainer/api/logs"
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
|
||||
"github.com/docker/docker/pkg/ioutils"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/scrypt"
|
||||
)
|
||||
|
||||
func init() {
|
||||
fips.InitFIPS(false)
|
||||
}
|
||||
|
||||
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
|
||||
func randBytes(n int) []byte {
|
||||
b := make([]byte, n)
|
||||
for i := range b {
|
||||
b[i] = letterBytes[rand.Intn(len(letterBytes))]
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
type encryptFunc func(input io.Reader, output io.Writer, passphrase []byte) error
|
||||
type decryptFunc func(input io.Reader, passphrase []byte) (io.Reader, error)
|
||||
|
||||
func Test_encryptAndDecrypt_withTheSamePassword(t *testing.T) {
|
||||
const passphrase = "passphrase"
|
||||
tmpdir, _ := ioutils.TempDir("", "encrypt")
|
||||
defer os.RemoveAll(tmpdir)
|
||||
|
||||
testFunc := func(t *testing.T, encrypt encryptFunc, decrypt decryptFunc, decryptShouldSucceed bool) {
|
||||
tmpdir := t.TempDir()
|
||||
var (
|
||||
originFilePath = filepath.Join(tmpdir, "origin")
|
||||
encryptedFilePath = filepath.Join(tmpdir, "encrypted")
|
||||
decryptedFilePath = filepath.Join(tmpdir, "decrypted")
|
||||
)
|
||||
|
||||
var (
|
||||
originFilePath = filepath.Join(tmpdir, "origin")
|
||||
encryptedFilePath = filepath.Join(tmpdir, "encrypted")
|
||||
decryptedFilePath = filepath.Join(tmpdir, "decrypted")
|
||||
)
|
||||
content := []byte("content")
|
||||
ioutil.WriteFile(originFilePath, content, 0600)
|
||||
|
||||
content := randBytes(1024*1024*100 + 523)
|
||||
err := os.WriteFile(originFilePath, content, 0600)
|
||||
require.NoError(t, err)
|
||||
originFile, _ := os.Open(originFilePath)
|
||||
defer originFile.Close()
|
||||
|
||||
originFile, _ := os.Open(originFilePath)
|
||||
defer logs.CloseAndLogErr(originFile)
|
||||
encryptedFileWriter, _ := os.Create(encryptedFilePath)
|
||||
defer encryptedFileWriter.Close()
|
||||
|
||||
encryptedFileWriter, _ := os.Create(encryptedFilePath)
|
||||
err := AesEncrypt(originFile, encryptedFileWriter, []byte("passphrase"))
|
||||
assert.Nil(t, err, "Failed to encrypt a file")
|
||||
encryptedContent, err := ioutil.ReadFile(encryptedFilePath)
|
||||
assert.Nil(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
|
||||
err = encrypt(originFile, encryptedFileWriter, []byte(passphrase))
|
||||
require.NoError(t, err, "Failed to encrypt a file")
|
||||
logs.CloseAndLogErr(encryptedFileWriter)
|
||||
encryptedFileReader, _ := os.Open(encryptedFilePath)
|
||||
defer encryptedFileReader.Close()
|
||||
|
||||
encryptedContent, err := os.ReadFile(encryptedFilePath)
|
||||
require.NoError(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
decryptedFileWriter, _ := os.Create(decryptedFilePath)
|
||||
defer decryptedFileWriter.Close()
|
||||
|
||||
encryptedFileReader, err := os.Open(encryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(encryptedFileReader)
|
||||
decryptedReader, err := AesDecrypt(encryptedFileReader, []byte("passphrase"))
|
||||
assert.Nil(t, err, "Failed to decrypt file")
|
||||
|
||||
decryptedFileWriter, err := os.Create(decryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(decryptedFileWriter)
|
||||
io.Copy(decryptedFileWriter, decryptedReader)
|
||||
|
||||
decryptedReader, err := decrypt(encryptedFileReader, []byte(passphrase))
|
||||
if !decryptShouldSucceed {
|
||||
require.Error(t, err, "Failed to decrypt file as indicated by decryptShouldSucceed")
|
||||
} else {
|
||||
require.NoError(t, err, "Failed to decrypt file indicated by decryptShouldSucceed")
|
||||
|
||||
_, err = io.Copy(decryptedFileWriter, decryptedReader)
|
||||
require.NoError(t, err)
|
||||
|
||||
decryptedContent, err := os.ReadFile(decryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("fips", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCMFIPS, aesDecryptGCMFIPS, true)
|
||||
})
|
||||
|
||||
t.Run("non_fips", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCM, aesDecryptGCM, true)
|
||||
})
|
||||
|
||||
t.Run("system_fips_mode_public_entry_points", func(t *testing.T) {
|
||||
// use the init mode, public entry points
|
||||
testFunc(t, AesEncrypt, AesDecrypt, true)
|
||||
})
|
||||
|
||||
t.Run("fips_encrypted_file_header_fails_in_non_fips_mode", func(t *testing.T) {
|
||||
// use aesDecrypt which checks the header, confirm that it fails
|
||||
decrypt := func(input io.Reader, passphrase []byte) (io.Reader, error) {
|
||||
return aesDecrypt(input, passphrase, false)
|
||||
}
|
||||
|
||||
testFunc(t, aesEncryptGCMFIPS, decrypt, false)
|
||||
})
|
||||
|
||||
t.Run("non_fips_encrypted_file_header_fails_in_fips_mode", func(t *testing.T) {
|
||||
// use aesDecrypt which checks the header, confirm that it fails
|
||||
decrypt := func(input io.Reader, passphrase []byte) (io.Reader, error) {
|
||||
return aesDecrypt(input, passphrase, true)
|
||||
}
|
||||
|
||||
testFunc(t, aesEncryptGCM, decrypt, false)
|
||||
})
|
||||
|
||||
t.Run("fips_encrypted_file_fails_in_non_fips_mode", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCMFIPS, aesDecryptGCM, false)
|
||||
})
|
||||
|
||||
t.Run("non_fips_encrypted_file_with_fips_mode_should_fail", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCM, aesDecryptGCMFIPS, false)
|
||||
})
|
||||
|
||||
t.Run("fips_with_base_aesDecrypt", func(t *testing.T) {
|
||||
// maximize coverage, use the base aesDecrypt function with valid fips mode
|
||||
decrypt := func(input io.Reader, passphrase []byte) (io.Reader, error) {
|
||||
return aesDecrypt(input, passphrase, true)
|
||||
}
|
||||
|
||||
testFunc(t, aesEncryptGCMFIPS, decrypt, true)
|
||||
})
|
||||
|
||||
t.Run("legacy", func(t *testing.T) {
|
||||
testFunc(t, legacyAesEncrypt, aesDecryptOFB, true)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_encryptAndDecrypt_withStrongPassphrase(t *testing.T) {
|
||||
const passphrase = "A strong passphrase with special characters: !@#$%^&*()_+"
|
||||
|
||||
testFunc := func(t *testing.T, encrypt encryptFunc, decrypt decryptFunc) {
|
||||
tmpdir := t.TempDir()
|
||||
|
||||
var (
|
||||
originFilePath = filepath.Join(tmpdir, "origin2")
|
||||
encryptedFilePath = filepath.Join(tmpdir, "encrypted2")
|
||||
decryptedFilePath = filepath.Join(tmpdir, "decrypted2")
|
||||
)
|
||||
|
||||
content := randBytes(500)
|
||||
|
||||
err := os.WriteFile(originFilePath, content, 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
originFile, err := os.Open(originFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(originFile)
|
||||
|
||||
encryptedFileWriter, _ := os.Create(encryptedFilePath)
|
||||
|
||||
err = encrypt(originFile, encryptedFileWriter, []byte(passphrase))
|
||||
require.NoError(t, err, "Failed to encrypt a file")
|
||||
logs.CloseAndLogErr(encryptedFileWriter)
|
||||
|
||||
encryptedContent, err := os.ReadFile(encryptedFilePath)
|
||||
require.NoError(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
|
||||
encryptedFileReader, err := os.Open(encryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(encryptedFileReader)
|
||||
|
||||
decryptedFileWriter, err := os.Create(decryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(decryptedFileWriter)
|
||||
|
||||
decryptedReader, err := decrypt(encryptedFileReader, []byte(passphrase))
|
||||
require.NoError(t, err, "Failed to decrypt file")
|
||||
|
||||
_, err = io.Copy(decryptedFileWriter, decryptedReader)
|
||||
require.NoError(t, err)
|
||||
|
||||
decryptedContent, err := os.ReadFile(decryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
|
||||
}
|
||||
|
||||
t.Run("fips", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCMFIPS, aesDecryptGCMFIPS)
|
||||
})
|
||||
|
||||
t.Run("non_fips", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCM, aesDecryptGCM)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_encryptAndDecrypt_withTheSamePasswordSmallFile(t *testing.T) {
|
||||
testFunc := func(t *testing.T, encrypt encryptFunc, decrypt decryptFunc) {
|
||||
tmpdir := t.TempDir()
|
||||
|
||||
var (
|
||||
originFilePath = filepath.Join(tmpdir, "origin2")
|
||||
encryptedFilePath = filepath.Join(tmpdir, "encrypted2")
|
||||
decryptedFilePath = filepath.Join(tmpdir, "decrypted2")
|
||||
)
|
||||
|
||||
content := randBytes(500)
|
||||
err := os.WriteFile(originFilePath, content, 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
originFile, err := os.Open(originFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(originFile)
|
||||
|
||||
encryptedFileWriter, err := os.Create(encryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = encrypt(originFile, encryptedFileWriter, []byte("passphrase"))
|
||||
require.NoError(t, err, "Failed to encrypt a file")
|
||||
logs.CloseAndLogErr(encryptedFileWriter)
|
||||
|
||||
encryptedContent, err := os.ReadFile(encryptedFilePath)
|
||||
require.NoError(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
|
||||
encryptedFileReader, err := os.Open(encryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(encryptedFileReader)
|
||||
|
||||
decryptedFileWriter, err := os.Create(decryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(decryptedFileWriter)
|
||||
|
||||
decryptedReader, err := decrypt(encryptedFileReader, []byte("passphrase"))
|
||||
require.NoError(t, err, "Failed to decrypt file")
|
||||
|
||||
_, err = io.Copy(decryptedFileWriter, decryptedReader)
|
||||
require.NoError(t, err)
|
||||
|
||||
decryptedContent, err := os.ReadFile(decryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
|
||||
}
|
||||
|
||||
t.Run("fips", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCMFIPS, aesDecryptGCMFIPS)
|
||||
})
|
||||
|
||||
t.Run("non_fips", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCM, aesDecryptGCM)
|
||||
})
|
||||
decryptedContent, _ := ioutil.ReadFile(decryptedFilePath)
|
||||
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
|
||||
}
|
||||
|
||||
func Test_encryptAndDecrypt_withEmptyPassword(t *testing.T) {
|
||||
testFunc := func(t *testing.T, encrypt encryptFunc, decrypt decryptFunc) {
|
||||
tmpdir := t.TempDir()
|
||||
tmpdir, _ := ioutils.TempDir("", "encrypt")
|
||||
defer os.RemoveAll(tmpdir)
|
||||
|
||||
var (
|
||||
originFilePath = filepath.Join(tmpdir, "origin")
|
||||
encryptedFilePath = filepath.Join(tmpdir, "encrypted")
|
||||
decryptedFilePath = filepath.Join(tmpdir, "decrypted")
|
||||
)
|
||||
var (
|
||||
originFilePath = filepath.Join(tmpdir, "origin")
|
||||
encryptedFilePath = filepath.Join(tmpdir, "encrypted")
|
||||
decryptedFilePath = filepath.Join(tmpdir, "decrypted")
|
||||
)
|
||||
|
||||
content := randBytes(1024 * 50)
|
||||
err := os.WriteFile(originFilePath, content, 0600)
|
||||
require.NoError(t, err)
|
||||
content := []byte("content")
|
||||
ioutil.WriteFile(originFilePath, content, 0600)
|
||||
|
||||
originFile, err := os.Open(originFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(originFile)
|
||||
originFile, _ := os.Open(originFilePath)
|
||||
defer originFile.Close()
|
||||
|
||||
encryptedFileWriter, err := os.Create(encryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(encryptedFileWriter)
|
||||
encryptedFileWriter, _ := os.Create(encryptedFilePath)
|
||||
defer encryptedFileWriter.Close()
|
||||
|
||||
err = encrypt(originFile, encryptedFileWriter, []byte(""))
|
||||
require.NoError(t, err, "Failed to encrypt a file")
|
||||
err := AesEncrypt(originFile, encryptedFileWriter, []byte(""))
|
||||
assert.Nil(t, err, "Failed to encrypt a file")
|
||||
encryptedContent, err := ioutil.ReadFile(encryptedFilePath)
|
||||
assert.Nil(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
|
||||
encryptedContent, err := os.ReadFile(encryptedFilePath)
|
||||
require.NoError(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
encryptedFileReader, _ := os.Open(encryptedFilePath)
|
||||
defer encryptedFileReader.Close()
|
||||
|
||||
encryptedFileReader, err := os.Open(encryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(encryptedFileReader)
|
||||
decryptedFileWriter, _ := os.Create(decryptedFilePath)
|
||||
defer decryptedFileWriter.Close()
|
||||
|
||||
decryptedFileWriter, err := os.Create(decryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(decryptedFileWriter)
|
||||
decryptedReader, err := AesDecrypt(encryptedFileReader, []byte(""))
|
||||
assert.Nil(t, err, "Failed to decrypt file")
|
||||
|
||||
decryptedReader, err := decrypt(encryptedFileReader, []byte(""))
|
||||
require.NoError(t, err, "Failed to decrypt file")
|
||||
io.Copy(decryptedFileWriter, decryptedReader)
|
||||
|
||||
_, err = io.Copy(decryptedFileWriter, decryptedReader)
|
||||
require.NoError(t, err)
|
||||
|
||||
decryptedContent, err := os.ReadFile(decryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
|
||||
}
|
||||
|
||||
t.Run("fips", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCMFIPS, aesDecryptGCMFIPS)
|
||||
})
|
||||
|
||||
t.Run("non_fips", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCM, aesDecryptGCM)
|
||||
})
|
||||
decryptedContent, _ := ioutil.ReadFile(decryptedFilePath)
|
||||
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
|
||||
}
|
||||
|
||||
func Test_decryptWithDifferentPassphrase_shouldProduceWrongResult(t *testing.T) {
|
||||
testFunc := func(t *testing.T, encrypt encryptFunc, decrypt decryptFunc) {
|
||||
tmpdir := t.TempDir()
|
||||
tmpdir, _ := ioutils.TempDir("", "encrypt")
|
||||
defer os.RemoveAll(tmpdir)
|
||||
|
||||
var (
|
||||
originFilePath = filepath.Join(tmpdir, "origin")
|
||||
encryptedFilePath = filepath.Join(tmpdir, "encrypted")
|
||||
decryptedFilePath = filepath.Join(tmpdir, "decrypted")
|
||||
)
|
||||
var (
|
||||
originFilePath = filepath.Join(tmpdir, "origin")
|
||||
encryptedFilePath = filepath.Join(tmpdir, "encrypted")
|
||||
decryptedFilePath = filepath.Join(tmpdir, "decrypted")
|
||||
)
|
||||
|
||||
content := randBytes(1034)
|
||||
err := os.WriteFile(originFilePath, content, 0600)
|
||||
require.NoError(t, err)
|
||||
content := []byte("content")
|
||||
ioutil.WriteFile(originFilePath, content, 0600)
|
||||
|
||||
originFile, err := os.Open(originFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(originFile)
|
||||
originFile, _ := os.Open(originFilePath)
|
||||
defer originFile.Close()
|
||||
|
||||
encryptedFileWriter, err := os.Create(encryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(encryptedFileWriter)
|
||||
encryptedFileWriter, _ := os.Create(encryptedFilePath)
|
||||
defer encryptedFileWriter.Close()
|
||||
|
||||
err = encrypt(originFile, encryptedFileWriter, []byte("passphrase"))
|
||||
require.NoError(t, err, "Failed to encrypt a file")
|
||||
encryptedContent, err := os.ReadFile(encryptedFilePath)
|
||||
require.NoError(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
err := AesEncrypt(originFile, encryptedFileWriter, []byte("passphrase"))
|
||||
assert.Nil(t, err, "Failed to encrypt a file")
|
||||
encryptedContent, err := ioutil.ReadFile(encryptedFilePath)
|
||||
assert.Nil(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
|
||||
encryptedFileReader, err := os.Open(encryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(encryptedFileReader)
|
||||
encryptedFileReader, _ := os.Open(encryptedFilePath)
|
||||
defer encryptedFileReader.Close()
|
||||
|
||||
decryptedFileWriter, err := os.Create(decryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(decryptedFileWriter)
|
||||
decryptedFileWriter, _ := os.Create(decryptedFilePath)
|
||||
defer decryptedFileWriter.Close()
|
||||
|
||||
_, err = decrypt(encryptedFileReader, []byte("garbage"))
|
||||
require.Error(t, err, "Should not allow decrypt with wrong passphrase")
|
||||
}
|
||||
decryptedReader, err := AesDecrypt(encryptedFileReader, []byte("garbage"))
|
||||
assert.Nil(t, err, "Should allow to decrypt with wrong passphrase")
|
||||
|
||||
t.Run("fips", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCMFIPS, aesDecryptGCMFIPS)
|
||||
})
|
||||
io.Copy(decryptedFileWriter, decryptedReader)
|
||||
|
||||
t.Run("non_fips", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCM, aesDecryptGCM)
|
||||
})
|
||||
}
|
||||
|
||||
func legacyAesEncrypt(input io.Reader, output io.Writer, passphrase []byte) error {
|
||||
key, err := scrypt.Key(passphrase, nil, 32768, 8, 1, 32)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var iv [aes.BlockSize]byte
|
||||
stream := cipher.NewOFB(block, iv[:])
|
||||
|
||||
writer := &cipher.StreamWriter{S: stream, W: output}
|
||||
if _, err := io.Copy(writer, input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func Test_hasEncryptedHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
fipsMode bool
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "non-FIPS mode with valid header",
|
||||
data: []byte("AES256-GCM" + "some encrypted data"),
|
||||
fipsMode: false,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "non-FIPS mode with FIPS header",
|
||||
data: []byte("FIPS-AES256-GCM" + "some encrypted data"),
|
||||
fipsMode: false,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "FIPS mode with valid header",
|
||||
data: []byte("FIPS-AES256-GCM" + "some encrypted data"),
|
||||
fipsMode: true,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "FIPS mode with non-FIPS header",
|
||||
data: []byte("AES256-GCM" + "some encrypted data"),
|
||||
fipsMode: true,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "invalid header",
|
||||
data: []byte("INVALID-HEADER" + "some data"),
|
||||
fipsMode: false,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty data",
|
||||
data: []byte{},
|
||||
fipsMode: false,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "nil data",
|
||||
data: nil,
|
||||
fipsMode: false,
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := hasEncryptedHeader(tt.data, tt.fipsMode)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
decryptedContent, _ := ioutil.ReadFile(decryptedFilePath)
|
||||
assert.NotEqual(t, content, decryptedContent, "Original and decrypted content should NOT match")
|
||||
}
|
||||
|
||||
@@ -7,8 +7,9 @@ import (
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"math/big"
|
||||
|
||||
"github.com/portainer/portainer/pkg/libcrypto"
|
||||
"github.com/portainer/libcrypto"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -112,7 +113,10 @@ func (service *ECDSAService) CreateSignature(message string) (string, error) {
|
||||
message = service.secret
|
||||
}
|
||||
|
||||
hash := libcrypto.InsecureHashFromBytes([]byte(message))
|
||||
hash := libcrypto.HashFromBytes([]byte(message))
|
||||
|
||||
r := big.NewInt(0)
|
||||
s := big.NewInt(0)
|
||||
|
||||
r, s, err := ecdsa.Sign(rand.Reader, service.privateKey, hash)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCreateSignature(t *testing.T) {
|
||||
var s = NewECDSAService("secret")
|
||||
|
||||
privKey, pubKey, err := s.GenerateKeyPair()
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, privKey)
|
||||
require.NotEmpty(t, pubKey)
|
||||
|
||||
m := "test message"
|
||||
r, err := s.CreateSignature(m)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, r, m)
|
||||
require.NotEmpty(t, r)
|
||||
}
|
||||
@@ -1,24 +1,22 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
// Not allowed in FIPS mode
|
||||
"golang.org/x/crypto/bcrypt" //nolint:depguard
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// Service represents a service for encrypting/hashing data.
|
||||
type Service struct{}
|
||||
|
||||
// Hash hashes a string using the bcrypt algorithm
|
||||
func (Service) Hash(data string) (string, error) {
|
||||
func (*Service) Hash(data string) (string, error) {
|
||||
bytes, err := bcrypt.GenerateFromPassword([]byte(data), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(bytes), err
|
||||
}
|
||||
|
||||
// CompareHashAndData compares a hash to clear data and returns an error if the comparison fails.
|
||||
func (Service) CompareHashAndData(hash string, data string) error {
|
||||
func (*Service) CompareHashAndData(hash string, data string) error {
|
||||
return bcrypt.CompareHashAndPassword([]byte(hash), []byte(data))
|
||||
}
|
||||
|
||||
@@ -2,12 +2,10 @@ package crypto
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestService_Hash(t *testing.T) {
|
||||
var s = Service{}
|
||||
var s = &Service{}
|
||||
|
||||
type args struct {
|
||||
hash string
|
||||
@@ -53,11 +51,3 @@ func TestService_Hash(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHash(t *testing.T) {
|
||||
s := Service{}
|
||||
|
||||
hash, err := s.Hash("Passw0rd!")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, hash)
|
||||
}
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
type Nonce struct {
|
||||
val []byte
|
||||
}
|
||||
|
||||
func NewNonce(size int) *Nonce {
|
||||
return &Nonce{val: make([]byte, size)}
|
||||
}
|
||||
|
||||
// NewRandomNonce generates a new initial nonce with the lower byte set to a random value
|
||||
// This ensures there are plenty of nonce values available before rolling over
|
||||
// Based on ideas from the Secure Programming Cookbook for C and C++ by John Viega, Matt Messier
|
||||
// https://www.oreilly.com/library/view/secure-programming-cookbook/0596003943/ch04s09.html
|
||||
func NewRandomNonce(size int) (*Nonce, error) {
|
||||
randomBytes := 1
|
||||
if size <= randomBytes {
|
||||
return nil, errors.New("nonce size must be greater than the number of random bytes")
|
||||
}
|
||||
|
||||
randomPart := make([]byte, randomBytes)
|
||||
if _, err := rand.Read(randomPart); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
zeroPart := make([]byte, size-randomBytes)
|
||||
nonceVal := append(randomPart, zeroPart...)
|
||||
return &Nonce{val: nonceVal}, nil
|
||||
}
|
||||
|
||||
func (n *Nonce) Read(stream io.Reader) error {
|
||||
_, err := io.ReadFull(stream, n.val)
|
||||
return err
|
||||
}
|
||||
|
||||
func (n *Nonce) Value() []byte {
|
||||
return n.val
|
||||
}
|
||||
|
||||
func (n *Nonce) Increment() error {
|
||||
// Start incrementing from the least significant byte
|
||||
for i := len(n.val) - 1; i >= 0; i-- {
|
||||
// Increment the current byte
|
||||
n.val[i]++
|
||||
|
||||
// Check for overflow
|
||||
if n.val[i] != 0 {
|
||||
// No overflow, nonce is successfully incremented
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// If we reach here, it means the nonce has overflowed
|
||||
return errors.New("nonce overflow")
|
||||
}
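A short usage sketch of the counter, assuming the Nonce type above: create one random nonce per stream, use its current value for each block, then advance it.

func nonceExample() error {
	n, err := NewRandomNonce(12) // 12 bytes is the usual GCM nonce size
	if err != nil {
		return err
	}

	for block := 0; block < 3; block++ {
		_ = n.Value() // pass this to Seal/Open for the current block

		// Increment only fails once every byte has wrapped around.
		if err := n.Increment(); err != nil {
			return err
		}
	}

	return nil
}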
@@ -3,33 +3,12 @@ package crypto
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"os"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
"io/ioutil"
|
||||
)
|
||||
|
||||
// CreateTLSConfiguration creates a basic tls.Config with recommended TLS settings
|
||||
func CreateTLSConfiguration(insecureSkipVerify bool) *tls.Config { //nolint:forbidigo
|
||||
return createTLSConfiguration(fips.FIPSMode(), insecureSkipVerify)
|
||||
}
|
||||
|
||||
func createTLSConfiguration(fipsEnabled bool, insecureSkipVerify bool) *tls.Config { //nolint:forbidigo
|
||||
if fipsEnabled {
|
||||
return &tls.Config{ //nolint:forbidigo
|
||||
MinVersion: tls.VersionTLS12,
|
||||
MaxVersion: tls.VersionTLS13,
|
||||
CipherSuites: []uint16{
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
},
|
||||
CurvePreferences: []tls.CurveID{tls.CurveP256, tls.CurveP384, tls.CurveP521},
|
||||
}
|
||||
}
|
||||
|
||||
return &tls.Config{ //nolint:forbidigo
|
||||
// CreateServerTLSConfiguration creates a basic tls.Config to be used by servers with recommended TLS settings
|
||||
func CreateServerTLSConfiguration() *tls.Config {
|
||||
return &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
CipherSuites: []uint16{
|
||||
tls.TLS_AES_128_GCM_SHA256,
|
||||
@@ -41,42 +20,25 @@ func createTLSConfiguration(fipsEnabled bool, insecureSkipVerify bool) *tls.Conf
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
|
||||
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
||||
tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
|
||||
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||
},
|
||||
InsecureSkipVerify: insecureSkipVerify, //nolint:forbidigo
|
||||
}
|
||||
}
|
||||
|
||||
// CreateTLSConfigurationFromBytes initializes a tls.Config using a CA certificate, a certificate and a key
|
||||
// loaded from memory.
|
||||
func CreateTLSConfigurationFromBytes(useTLS bool, caCert, cert, key []byte, skipClientVerification, skipServerVerification bool) (*tls.Config, error) { //nolint:forbidigo
|
||||
return createTLSConfigurationFromBytes(fips.FIPSMode(), useTLS, caCert, cert, key, skipClientVerification, skipServerVerification)
|
||||
}
|
||||
func CreateTLSConfigurationFromBytes(caCert, cert, key []byte, skipClientVerification, skipServerVerification bool) (*tls.Config, error) {
|
||||
config := &tls.Config{}
|
||||
config.InsecureSkipVerify = skipServerVerification
|
||||
|
||||
func createTLSConfigurationFromBytes(fipsEnabled, useTLS bool, caCert, cert, key []byte, skipClientVerification, skipServerVerification bool) (*tls.Config, error) { //nolint:forbidigo
|
||||
if !useTLS {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
config := createTLSConfiguration(fipsEnabled, skipServerVerification)
|
||||
|
||||
if !skipClientVerification || fipsEnabled {
|
||||
if !skipClientVerification {
|
||||
certificate, err := tls.X509KeyPair(cert, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config.Certificates = []tls.Certificate{certificate}
|
||||
}
|
||||
|
||||
if !skipServerVerification || fipsEnabled {
|
||||
if !skipServerVerification {
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCertPool.AppendCertsFromPEM(caCert)
|
||||
config.RootCAs = caCertPool
|
||||
@@ -87,38 +49,29 @@ func createTLSConfigurationFromBytes(fipsEnabled, useTLS bool, caCert, cert, key
|
||||
|
||||
// CreateTLSConfigurationFromDisk initializes a tls.Config using a CA certificate, a certificate and a key
|
||||
// loaded from disk.
|
||||
func CreateTLSConfigurationFromDisk(config portainer.TLSConfiguration) (*tls.Config, error) { //nolint:forbidigo
|
||||
return createTLSConfigurationFromDisk(fips.FIPSMode(), config)
|
||||
}
|
||||
func CreateTLSConfigurationFromDisk(caCertPath, certPath, keyPath string, skipServerVerification bool) (*tls.Config, error) {
|
||||
config := &tls.Config{}
|
||||
config.InsecureSkipVerify = skipServerVerification
|
||||
|
||||
func createTLSConfigurationFromDisk(fipsEnabled bool, config portainer.TLSConfiguration) (*tls.Config, error) { //nolint:forbidigo
|
||||
if !config.TLS && fipsEnabled {
|
||||
return nil, fips.ErrTLSRequired
|
||||
} else if !config.TLS {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
tlsConfig := createTLSConfiguration(fipsEnabled, config.TLSSkipVerify)
|
||||
|
||||
if config.TLSCertPath != "" && config.TLSKeyPath != "" {
|
||||
cert, err := tls.LoadX509KeyPair(config.TLSCertPath, config.TLSKeyPath)
|
||||
if certPath != "" && keyPath != "" {
|
||||
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||
config.Certificates = []tls.Certificate{cert}
|
||||
}
|
||||
|
||||
if !tlsConfig.InsecureSkipVerify && config.TLSCACertPath != "" { //nolint:forbidigo
|
||||
caCert, err := os.ReadFile(config.TLSCACertPath)
|
||||
if !skipServerVerification && caCertPath != "" {
|
||||
caCert, err := ioutil.ReadFile(caCertPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCertPool.AppendCertsFromPEM(caCert)
|
||||
tlsConfig.RootCAs = caCertPool
|
||||
config.RootCAs = caCertPool
|
||||
}
|
||||
|
||||
return tlsConfig, nil
|
||||
return config, nil
|
||||
}
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCreateTLSConfiguration(t *testing.T) {
|
||||
// InsecureSkipVerify = false
|
||||
config := CreateTLSConfiguration(false)
|
||||
require.Equal(t, config.MinVersion, uint16(tls.VersionTLS12)) //nolint:forbidigo
|
||||
require.False(t, config.InsecureSkipVerify) //nolint:forbidigo
|
||||
|
||||
// InsecureSkipVerify = true
|
||||
config = CreateTLSConfiguration(true)
|
||||
require.Equal(t, config.MinVersion, uint16(tls.VersionTLS12)) //nolint:forbidigo
|
||||
require.True(t, config.InsecureSkipVerify) //nolint:forbidigo
|
||||
}
|
||||
|
||||
func TestCreateTLSConfigurationFIPS(t *testing.T) {
|
||||
fips := true
|
||||
|
||||
fipsCipherSuites := []uint16{
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
}
|
||||
|
||||
fipsCurvePreferences := []tls.CurveID{tls.CurveP256, tls.CurveP384, tls.CurveP521}
|
||||
|
||||
config := createTLSConfiguration(fips, false)
|
||||
require.Equal(t, config.MinVersion, uint16(tls.VersionTLS12)) //nolint:forbidigo
|
||||
require.Equal(t, config.MaxVersion, uint16(tls.VersionTLS13)) //nolint:forbidigo
|
||||
require.Equal(t, config.CipherSuites, fipsCipherSuites) //nolint:forbidigo
|
||||
require.Equal(t, config.CurvePreferences, fipsCurvePreferences) //nolint:forbidigo
|
||||
require.False(t, config.InsecureSkipVerify) //nolint:forbidigo
|
||||
}
|
||||
|
||||
func TestCreateTLSConfigurationFromBytes(t *testing.T) {
|
||||
// No TLS
|
||||
config, err := CreateTLSConfigurationFromBytes(false, nil, nil, nil, false, false)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, config)
|
||||
|
||||
// Skip TLS client/server verifications
|
||||
config, err = CreateTLSConfigurationFromBytes(true, nil, nil, nil, true, true)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
|
||||
// Empty TLS
|
||||
config, err = CreateTLSConfigurationFromBytes(true, nil, nil, nil, false, false)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, config)
|
||||
}
|
||||
|
||||
func TestCreateTLSConfigurationFromDisk(t *testing.T) {
|
||||
// No TLS
|
||||
config, err := CreateTLSConfigurationFromDisk(portainer.TLSConfiguration{})
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, config)
|
||||
|
||||
// Skip TLS verifications
|
||||
config, err = CreateTLSConfigurationFromDisk(portainer.TLSConfiguration{
|
||||
TLS: true,
|
||||
TLSSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
}
|
||||
|
||||
func TestCreateTLSConfigurationFromDiskFIPS(t *testing.T) {
|
||||
fips := true
|
||||
|
||||
// Skipping TLS verifications cannot be done in FIPS mode
|
||||
config, err := createTLSConfigurationFromDisk(fips, portainer.TLSConfiguration{
|
||||
TLS: true,
|
||||
TLSSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
require.False(t, config.InsecureSkipVerify) //nolint:forbidigo
|
||||
}
|
||||
@@ -5,25 +5,19 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
dserrors "github.com/portainer/portainer/api/dataservices/errors"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/sirupsen/logrus"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
const (
|
||||
DatabaseFileName = "portainer.db"
|
||||
EncryptedDatabaseFileName = "portainer.edb"
|
||||
|
||||
txMaxSize = 65536
|
||||
compactedSuffix = ".compacted"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -38,7 +32,6 @@ type DbConnection struct {
|
||||
InitialMmapSize int
|
||||
EncryptionKey []byte
|
||||
isEncrypted bool
|
||||
Compact bool
|
||||
|
||||
*bolt.DB
|
||||
}
|
||||
@@ -66,15 +59,6 @@ func (connection *DbConnection) GetStorePath() string {
|
||||
return connection.Path
|
||||
}
|
||||
|
||||
func (connection *DbConnection) GetDatabaseFileSize() (int64, error) {
|
||||
file, err := os.Stat(connection.GetDatabaseFilePath())
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("Failed to stat database file path: %s err: %w", connection.GetDatabaseFilePath(), err)
|
||||
}
|
||||
|
||||
return file.Size(), nil
|
||||
}
|
||||
|
||||
func (connection *DbConnection) SetEncrypted(flag bool) {
|
||||
connection.isEncrypted = flag
|
||||
}
|
||||
@@ -87,6 +71,7 @@ func (connection *DbConnection) IsEncryptedStore() bool {
|
||||
// NeedsEncryptionMigration returns true if database encryption is enabled and
|
||||
// we have an un-encrypted DB that requires migration to an encrypted DB
|
||||
func (connection *DbConnection) NeedsEncryptionMigration() (bool, error) {
|
||||
|
||||
// Cases: Note, we need to check both portainer.db and portainer.edb
|
||||
// to determine if it's a new store. We only need to differentiate between cases 2,3 and 5
|
||||
|
||||
@@ -134,77 +119,38 @@ func (connection *DbConnection) NeedsEncryptionMigration() (bool, error) {
|
||||
|
||||
// Open opens and initializes the BoltDB database.
|
||||
func (connection *DbConnection) Open() error {
|
||||
log.Info().Str("filename", connection.GetDatabaseFileName()).Msg("loading PortainerDB")
|
||||
|
||||
logrus.Infof("Loading PortainerDB: %s", connection.GetDatabaseFileName())
|
||||
|
||||
// Now we open the db
|
||||
databasePath := connection.GetDatabaseFilePath()
|
||||
db, err := bolt.Open(databasePath, 0600, connection.boltOptions(connection.Compact))
|
||||
db, err := bolt.Open(databasePath, 0600, &bolt.Options{
|
||||
Timeout: 1 * time.Second,
|
||||
InitialMmapSize: connection.InitialMmapSize,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
db.MaxBatchSize = connection.MaxBatchSize
|
||||
db.MaxBatchDelay = connection.MaxBatchDelay
|
||||
connection.DB = db
|
||||
|
||||
if connection.Compact {
|
||||
log.Info().Msg("compacting database")
|
||||
if err := connection.compact(); err != nil {
|
||||
log.Error().Err(err).Msg("failed to compact database")
|
||||
|
||||
// Close the read-only database and re-open in read-write mode
|
||||
if err := connection.Close(); err != nil {
|
||||
log.Warn().Err(err).Msg("failure to close the database after failed compaction")
|
||||
}
|
||||
|
||||
connection.Compact = false
|
||||
|
||||
return connection.Open()
|
||||
} else {
|
||||
log.Info().Msg("database compaction completed")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the BoltDB database.
|
||||
// Safe to being called multiple times.
|
||||
func (connection *DbConnection) Close() error {
|
||||
log.Info().Msg("closing PortainerDB")
|
||||
|
||||
if connection.DB != nil {
|
||||
return connection.DB.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (connection *DbConnection) txFn(fn func(portainer.Transaction) error) func(*bolt.Tx) error {
|
||||
return func(tx *bolt.Tx) error {
|
||||
return fn(&DbTransaction{conn: connection, tx: tx})
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateTx executes the given function inside a read-write transaction
|
||||
func (connection *DbConnection) UpdateTx(fn func(portainer.Transaction) error) error {
|
||||
if connection.MaxBatchDelay > 0 && connection.MaxBatchSize > 1 {
|
||||
return connection.Batch(connection.txFn(fn))
|
||||
}
|
||||
|
||||
return connection.Update(connection.txFn(fn))
|
||||
}
|
||||
|
||||
// ViewTx executes the given function inside a read-only transaction
|
||||
func (connection *DbConnection) ViewTx(fn func(portainer.Transaction) error) error {
|
||||
return connection.View(connection.txFn(fn))
|
||||
}
|
||||
|
||||
// BackupTo backs up db to a provided writer.
|
||||
// It does hot backup and doesn't block other database reads and writes
|
||||
func (connection *DbConnection) BackupTo(w io.Writer) error {
|
||||
return connection.View(func(tx *bolt.Tx) error {
|
||||
_, err := tx.WriteTo(w)
|
||||
|
||||
return err
|
||||
})
|
||||
}
|
||||
@@ -212,15 +158,14 @@ func (connection *DbConnection) BackupTo(w io.Writer) error {
|
||||
func (connection *DbConnection) ExportRaw(filename string) error {
|
||||
databasePath := connection.GetDatabaseFilePath()
|
||||
if _, err := os.Stat(databasePath); err != nil {
|
||||
return fmt.Errorf("stat on %s failed, error: %w", databasePath, err)
|
||||
return fmt.Errorf("stat on %s failed: %s", databasePath, err)
|
||||
}
|
||||
|
||||
b, err := connection.ExportJSON(databasePath, true)
|
||||
b, err := connection.ExportJson(databasePath, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(filename, b, 0600)
|
||||
return ioutil.WriteFile(filename, b, 0600)
|
||||
}
|
||||
|
||||
// ConvertToKey returns an 8-byte big endian representation of v.
|
||||
@@ -229,62 +174,39 @@ func (connection *DbConnection) ExportRaw(filename string) error {
|
||||
func (connection *DbConnection) ConvertToKey(v int) []byte {
|
||||
b := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(b, uint64(v))
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// keyToString Converts a key to a string value suitable for logging
|
||||
func keyToString(b []byte) string {
|
||||
if len(b) != 8 {
|
||||
return string(b)
|
||||
}
|
||||
|
||||
v := binary.BigEndian.Uint64(b)
|
||||
if v <= math.MaxInt32 {
|
||||
return strconv.FormatUint(v, 10)
|
||||
}
|
||||
|
||||
return string(b)
|
||||
}
|
||||
|
||||
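Editorial aside, not part of the diff: a minimal, self-contained sketch of the big-endian round trip that ConvertToKey and keyToString above perform, using only the standard library. The value 42 is illustrative.

package main

import (
    "encoding/binary"
    "fmt"
    "strconv"
)

func main() {
    // ConvertToKey(42): an 8-byte big-endian key, as stored in BoltDB.
    key := make([]byte, 8)
    binary.BigEndian.PutUint64(key, uint64(42))

    // keyToString: decode the key back into a printable value for logging.
    v := binary.BigEndian.Uint64(key)
    fmt.Println(strconv.FormatUint(v, 10)) // prints "42"
}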
// CreateBucket is a generic function used to create a bucket inside a database.
|
||||
// CreateBucket is a generic function used to create a bucket inside a database database.
|
||||
func (connection *DbConnection) SetServiceName(bucketName string) error {
|
||||
return connection.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return tx.SetServiceName(bucketName)
|
||||
})
|
||||
}
|
||||
|
||||
// GetObject is a generic function used to retrieve an unmarshalled object from a database.
|
||||
func (connection *DbConnection) GetObject(bucketName string, key []byte, object any) error {
|
||||
return connection.ViewTx(func(tx portainer.Transaction) error {
|
||||
return tx.GetObject(bucketName, key, object)
|
||||
})
|
||||
}
|
||||
|
||||
func (connection *DbConnection) GetRawBytes(bucketName string, key []byte) ([]byte, error) {
|
||||
var value []byte
|
||||
|
||||
err := connection.ViewTx(func(tx portainer.Transaction) error {
|
||||
var err error
|
||||
value, err = tx.GetRawBytes(bucketName, key)
|
||||
|
||||
return connection.Batch(func(tx *bolt.Tx) error {
|
||||
_, err := tx.CreateBucketIfNotExists([]byte(bucketName))
|
||||
return err
|
||||
})
|
||||
|
||||
return value, err
|
||||
}
|
||||
|
||||
func (connection *DbConnection) KeyExists(bucketName string, key []byte) (bool, error) {
|
||||
var exists bool
|
||||
// GetObject is a generic function used to retrieve an unmarshalled object from a database database.
|
||||
func (connection *DbConnection) GetObject(bucketName string, key []byte, object interface{}) error {
|
||||
var data []byte
|
||||
|
||||
err := connection.ViewTx(func(tx portainer.Transaction) error {
|
||||
var err error
|
||||
exists, err = tx.KeyExists(bucketName, key)
|
||||
err := connection.View(func(tx *bolt.Tx) error {
|
||||
bucket := tx.Bucket([]byte(bucketName))
|
||||
|
||||
return err
|
||||
value := bucket.Get(key)
|
||||
if value == nil {
|
||||
return dserrors.ErrObjectNotFound
|
||||
}
|
||||
|
||||
data = make([]byte, len(value))
|
||||
copy(data, value)
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return exists, err
|
||||
return connection.UnmarshalObjectWithJsoniter(data, object)
|
||||
}
|
||||
|
||||
func (connection *DbConnection) getEncryptionKey() []byte {
|
||||
@@ -295,51 +217,50 @@ func (connection *DbConnection) getEncryptionKey() []byte {
|
||||
return connection.EncryptionKey
|
||||
}
|
||||
|
||||
// UpdateObject is a generic function used to update an object inside a database.
|
||||
func (connection *DbConnection) UpdateObject(bucketName string, key []byte, object any) error {
|
||||
return connection.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return tx.UpdateObject(bucketName, key, object)
|
||||
})
|
||||
}
|
||||
// UpdateObject is a generic function used to update an object inside a database database.
|
||||
func (connection *DbConnection) UpdateObject(bucketName string, key []byte, object interface{}) error {
|
||||
data, err := connection.MarshalObject(object)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateObjectFunc is a generic function used to update an object safely without race conditions.
|
||||
func (connection *DbConnection) UpdateObjectFunc(bucketName string, key []byte, object any, updateFn func()) error {
|
||||
return connection.Batch(func(tx *bolt.Tx) error {
|
||||
bucket := tx.Bucket([]byte(bucketName))
|
||||
|
||||
data := bucket.Get(key)
|
||||
if data == nil {
|
||||
return fmt.Errorf("%w (bucket=%s, key=%s)", dserrors.ErrObjectNotFound, bucketName, keyToString(key))
|
||||
}
|
||||
|
||||
err := connection.UnmarshalObject(data, object)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateFn()
|
||||
|
||||
data, err = connection.MarshalObject(object)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bucket.Put(key, data)
|
||||
})
|
||||
}
|
||||
|
||||
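Editorial aside, not in the diff: a hypothetical caller of UpdateObjectFunc above, sketching the read-modify-write it wraps in a single Batch transaction. The bucket name, key and Settings field are illustrative assumptions, not taken from the diff.

package boltdbexample

import (
    portainer "github.com/portainer/portainer/api"
    "github.com/portainer/portainer/api/database/boltdb"
)

// bumpSnapshotInterval mutates a stored object inside one transaction:
// the object is read and unmarshalled first, then the callback runs, then
// the result is written back, so concurrent updates cannot interleave.
func bumpSnapshotInterval(conn *boltdb.DbConnection, key []byte) error {
    var settings portainer.Settings

    return conn.UpdateObjectFunc("settings", key, &settings, func() {
        settings.SnapshotInterval = "10m"
    })
}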
// DeleteObject is a generic function used to delete an object inside a database.
|
||||
// DeleteObject is a generic function used to delete an object inside a database database.
|
||||
func (connection *DbConnection) DeleteObject(bucketName string, key []byte) error {
|
||||
return connection.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return tx.DeleteObject(bucketName, key)
|
||||
return connection.Batch(func(tx *bolt.Tx) error {
|
||||
bucket := tx.Bucket([]byte(bucketName))
|
||||
return bucket.Delete(key)
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteAllObjects delete all objects where matching() returns (id, ok).
|
||||
// TODO: think about how to return the error inside (maybe change ok to type err, and use "notfound"?
|
||||
func (connection *DbConnection) DeleteAllObjects(bucketName string, obj any, matching func(o any) (id int, ok bool)) error {
|
||||
return connection.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return tx.DeleteAllObjects(bucketName, obj, matching)
|
||||
func (connection *DbConnection) DeleteAllObjects(bucketName string, matching func(o interface{}) (id int, ok bool)) error {
|
||||
return connection.Batch(func(tx *bolt.Tx) error {
|
||||
bucket := tx.Bucket([]byte(bucketName))
|
||||
|
||||
cursor := bucket.Cursor()
|
||||
for k, v := cursor.First(); k != nil; k, v = cursor.Next() {
|
||||
var obj interface{}
|
||||
err := connection.UnmarshalObject(v, &obj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if id, ok := matching(obj); ok {
|
||||
err := bucket.Delete(connection.ConvertToKey(id))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
@@ -347,8 +268,13 @@ func (connection *DbConnection) DeleteAllObjects(bucketName string, obj any, mat
|
||||
func (connection *DbConnection) GetNextIdentifier(bucketName string) int {
|
||||
var identifier int
|
||||
|
||||
_ = connection.UpdateTx(func(tx portainer.Transaction) error {
|
||||
identifier = tx.GetNextIdentifier(bucketName)
|
||||
connection.Batch(func(tx *bolt.Tx) error {
|
||||
bucket := tx.Bucket([]byte(bucketName))
|
||||
id, err := bucket.NextSequence()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
identifier = int(id)
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -356,64 +282,136 @@ func (connection *DbConnection) GetNextIdentifier(bucketName string) int {
|
||||
}
|
||||
|
||||
// CreateObject creates a new object in the bucket, using the next bucket sequence id
|
||||
func (connection *DbConnection) CreateObject(bucketName string, fn func(uint64) (int, any)) error {
|
||||
return connection.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return tx.CreateObject(bucketName, fn)
|
||||
func (connection *DbConnection) CreateObject(bucketName string, fn func(uint64) (int, interface{})) error {
|
||||
return connection.Batch(func(tx *bolt.Tx) error {
|
||||
bucket := tx.Bucket([]byte(bucketName))
|
||||
|
||||
seqId, _ := bucket.NextSequence()
|
||||
id, obj := fn(seqId)
|
||||
|
||||
data, err := connection.MarshalObject(obj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bucket.Put(connection.ConvertToKey(int(id)), data)
|
||||
})
|
||||
}
|
||||
|
||||
// CreateObjectWithId creates a new object in the bucket, using the specified id
|
||||
func (connection *DbConnection) CreateObjectWithId(bucketName string, id int, obj any) error {
|
||||
return connection.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return tx.CreateObjectWithId(bucketName, id, obj)
|
||||
func (connection *DbConnection) CreateObjectWithId(bucketName string, id int, obj interface{}) error {
|
||||
return connection.Batch(func(tx *bolt.Tx) error {
|
||||
bucket := tx.Bucket([]byte(bucketName))
|
||||
data, err := connection.MarshalObject(obj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bucket.Put(connection.ConvertToKey(id), data)
|
||||
})
|
||||
}
|
||||
|
||||
// CreateObjectWithStringId creates a new object in the bucket, using the specified id
|
||||
func (connection *DbConnection) CreateObjectWithStringId(bucketName string, id []byte, obj any) error {
|
||||
return connection.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return tx.CreateObjectWithStringId(bucketName, id, obj)
|
||||
func (connection *DbConnection) CreateObjectWithStringId(bucketName string, id []byte, obj interface{}) error {
|
||||
return connection.Batch(func(tx *bolt.Tx) error {
|
||||
bucket := tx.Bucket([]byte(bucketName))
|
||||
data, err := connection.MarshalObject(obj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bucket.Put(id, data)
|
||||
})
|
||||
}
|
||||
|
||||
func (connection *DbConnection) GetAll(bucketName string, obj any, appendFn func(o any) (any, error)) error {
|
||||
return connection.ViewTx(func(tx portainer.Transaction) error {
|
||||
return tx.GetAll(bucketName, obj, appendFn)
|
||||
// CreateObjectWithSetSequence creates a new object in the bucket, using the specified id, and sets the bucket sequence
|
||||
// avoid this :)
|
||||
func (connection *DbConnection) CreateObjectWithSetSequence(bucketName string, id int, obj interface{}) error {
|
||||
return connection.Batch(func(tx *bolt.Tx) error {
|
||||
bucket := tx.Bucket([]byte(bucketName))
|
||||
|
||||
// We manually manage sequences for schedules
|
||||
err := bucket.SetSequence(uint64(id))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data, err := connection.MarshalObject(obj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bucket.Put(connection.ConvertToKey(id), data)
|
||||
})
|
||||
}
|
||||
|
||||
func (connection *DbConnection) GetAllWithKeyPrefix(bucketName string, keyPrefix []byte, obj any, appendFn func(o any) (any, error)) error {
|
||||
return connection.ViewTx(func(tx portainer.Transaction) error {
|
||||
return tx.GetAllWithKeyPrefix(bucketName, keyPrefix, obj, appendFn)
|
||||
func (connection *DbConnection) GetAll(bucketName string, obj interface{}, append func(o interface{}) (interface{}, error)) error {
|
||||
err := connection.View(func(tx *bolt.Tx) error {
|
||||
bucket := tx.Bucket([]byte(bucketName))
|
||||
cursor := bucket.Cursor()
|
||||
for k, v := cursor.First(); k != nil; k, v = cursor.Next() {
|
||||
err := connection.UnmarshalObject(v, obj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
obj, err = append(obj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// BackupMetadata will return a copy of the boltdb sequence numbers for all buckets.
|
||||
func (connection *DbConnection) BackupMetadata() (map[string]any, error) {
|
||||
buckets := map[string]any{}
|
||||
// TODO: decide which Unmarshal to use, and use one...
|
||||
func (connection *DbConnection) GetAllWithJsoniter(bucketName string, obj interface{}, append func(o interface{}) (interface{}, error)) error {
|
||||
err := connection.View(func(tx *bolt.Tx) error {
|
||||
bucket := tx.Bucket([]byte(bucketName))
|
||||
|
||||
cursor := bucket.Cursor()
|
||||
for k, v := cursor.First(); k != nil; k, v = cursor.Next() {
|
||||
err := connection.UnmarshalObjectWithJsoniter(v, obj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
obj, err = append(obj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (connection *DbConnection) BackupMetadata() (map[string]interface{}, error) {
|
||||
buckets := map[string]interface{}{}
|
||||
|
||||
err := connection.View(func(tx *bolt.Tx) error {
|
||||
return tx.ForEach(func(name []byte, bucket *bolt.Bucket) error {
|
||||
err := tx.ForEach(func(name []byte, bucket *bolt.Bucket) error {
|
||||
bucketName := string(name)
|
||||
bucket = tx.Bucket([]byte(bucketName))
|
||||
seqId := bucket.Sequence()
|
||||
buckets[bucketName] = int(seqId)
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
return buckets, err
|
||||
}
|
||||
|
||||
// RestoreMetadata will restore the boltdb sequence numbers for all buckets.
func (connection *DbConnection) RestoreMetadata(s map[string]any) error {
func (connection *DbConnection) RestoreMetadata(s map[string]interface{}) error {
var err error

for bucketName, v := range s {
id, ok := v.(float64) // JSON ints are unmarshalled to interface as float64. See: https://pkg.go.dev/encoding/json#Decoder.Decode
if !ok {
log.Error().Str("bucket", bucketName).Msg("failed to restore metadata to bucket, skipped")

logrus.Errorf("Failed to restore metadata to bucket %s, skipped", bucketName)
continue
}

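Editorial aside, not part of the diff: a minimal sketch of the encoding/json behaviour the comment above relies on. Numbers decoded into an untyped interface value always arrive as float64, so the sequence id has to be converted back to an integer before SetSequence. The bucket name is illustrative.

package main

import (
    "encoding/json"
    "fmt"
)

func main() {
    // BackupMetadata stores sequence numbers as ints, but after a JSON
    // round trip into map[string]any they come back as float64.
    var meta map[string]any
    _ = json.Unmarshal([]byte(`{"endpoints": 42}`), &meta)

    id, ok := meta["endpoints"].(float64)
    fmt.Println(id, ok)     // 42 true
    fmt.Println(uint64(id)) // 42, ready for bucket.SetSequence
}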
@@ -422,55 +420,9 @@ func (connection *DbConnection) RestoreMetadata(s map[string]any) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bucket.SetSequence(uint64(id))
|
||||
})
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// compact attempts to compact the database and replace it iff it succeeds
|
||||
func (connection *DbConnection) compact() (err error) {
|
||||
compactedPath := connection.GetDatabaseFilePath() + compactedSuffix
|
||||
|
||||
if err := os.Remove(compactedPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return fmt.Errorf("failure to remove an existing compacted database: %w", err)
|
||||
}
|
||||
|
||||
compactedDB, err := bolt.Open(compactedPath, 0o600, connection.boltOptions(false))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failure to create the compacted database: %w", err)
|
||||
}
|
||||
|
||||
compactedDB.MaxBatchSize = connection.MaxBatchSize
|
||||
compactedDB.MaxBatchDelay = connection.MaxBatchDelay
|
||||
|
||||
if err := bolt.Compact(compactedDB, connection.DB, txMaxSize); err != nil {
|
||||
return fmt.Errorf("failure to compact the database: %w",
|
||||
errors.Join(err, compactedDB.Close(), os.Remove(compactedPath)))
|
||||
}
|
||||
|
||||
if err := os.Rename(compactedPath, connection.GetDatabaseFilePath()); err != nil {
|
||||
return fmt.Errorf("failure to move the compacted database: %w",
|
||||
errors.Join(err, compactedDB.Close(), os.Remove(compactedPath)))
|
||||
}
|
||||
|
||||
if err := connection.Close(); err != nil {
|
||||
log.Warn().Err(err).Msg("failure to close the database after compaction")
|
||||
}
|
||||
|
||||
connection.DB = compactedDB
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (connection *DbConnection) boltOptions(readOnly bool) *bolt.Options {
|
||||
return &bolt.Options{
|
||||
Timeout: 1 * time.Second,
|
||||
InitialMmapSize: connection.InitialMmapSize,
|
||||
FreelistType: bolt.FreelistMapType,
|
||||
NoFreelistSync: true,
|
||||
ReadOnly: readOnly,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,11 +5,7 @@ import (
|
||||
"path"
|
||||
"testing"
|
||||
|
||||
"github.com/portainer/portainer/api/filesystem"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
func Test_NeedsEncryptionMigration(t *testing.T) {
|
||||
@@ -91,43 +87,28 @@ func Test_NeedsEncryptionMigration(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
||||
connection := DbConnection{Path: dir}
|
||||
|
||||
if tc.dbname == "both" {
|
||||
// Special case. If portainer.db and portainer.edb exist.
|
||||
dbFile1 := path.Join(connection.Path, DatabaseFileName)
|
||||
f, _ := os.Create(dbFile1)
|
||||
|
||||
err := f.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
err := os.Remove(dbFile1)
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
f.Close()
|
||||
defer os.Remove(dbFile1)
|
||||
|
||||
dbFile2 := path.Join(connection.Path, EncryptedDatabaseFileName)
|
||||
f, _ = os.Create(dbFile2)
|
||||
|
||||
err = f.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
err := os.Remove(dbFile2)
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
f.Close()
|
||||
defer os.Remove(dbFile2)
|
||||
} else if tc.dbname != "" {
|
||||
dbFile := path.Join(connection.Path, tc.dbname)
|
||||
f, _ := os.Create(dbFile)
|
||||
|
||||
err := f.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
err := os.Remove(dbFile)
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
f.Close()
|
||||
defer os.Remove(dbFile)
|
||||
}
|
||||
|
||||
if tc.key {
|
||||
@@ -141,60 +122,3 @@ func Test_NeedsEncryptionMigration(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBCompaction(t *testing.T) {
|
||||
db := &DbConnection{Path: t.TempDir()}
|
||||
|
||||
err := db.Open()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Update(func(tx *bbolt.Tx) error {
|
||||
b, err := tx.CreateBucketIfNotExists([]byte("testbucket"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = b.Put([]byte("key"), []byte("value"))
|
||||
require.NoError(t, err)
|
||||
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Reopen the DB to trigger compaction
|
||||
db.Compact = true
|
||||
err = db.Open()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that the data is still there
|
||||
err = db.View(func(tx *bbolt.Tx) error {
|
||||
b := tx.Bucket([]byte("testbucket"))
|
||||
if b == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
val := b.Get([]byte("key"))
|
||||
require.Equal(t, []byte("value"), val)
|
||||
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Failures
|
||||
compactedPath := db.GetDatabaseFilePath() + compactedSuffix
|
||||
err = os.Mkdir(compactedPath, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
f, err := os.Create(filesystem.JoinPaths(compactedPath, "somefile"))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, f.Close())
|
||||
|
||||
err = db.Open()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
package boltdb
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/portainer/portainer/api/logs"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/segmentio/encoding/json"
|
||||
"github.com/sirupsen/logrus"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
func backupMetadata(connection *bolt.DB) (map[string]any, error) {
|
||||
buckets := map[string]any{}
|
||||
func backupMetadata(connection *bolt.DB) (map[string]interface{}, error) {
|
||||
buckets := map[string]interface{}{}
|
||||
|
||||
err := connection.View(func(tx *bolt.Tx) error {
|
||||
err := tx.ForEach(func(name []byte, bucket *bolt.Bucket) error {
|
||||
bucketName := string(name)
|
||||
bucket = tx.Bucket([]byte(bucketName))
|
||||
seqId := bucket.Sequence()
|
||||
buckets[bucketName] = int(seqId)
|
||||
return nil
|
||||
@@ -28,79 +28,72 @@ func backupMetadata(connection *bolt.DB) (map[string]any, error) {
|
||||
|
||||
// ExportJSON creates a JSON representation from a DbConnection. You can include
|
||||
// the database's metadata or ignore it. Ensure the database is closed before
|
||||
// using this function.
|
||||
// using this function
|
||||
// inspired by github.com/konoui/boltdb-exporter (which has no license)
|
||||
// but very much simplified, based on how we use boltdb
|
||||
func (c *DbConnection) ExportJSON(databasePath string, metadata bool) ([]byte, error) {
|
||||
log.Debug().Str("databasePath", databasePath).Msg("exportJson")
|
||||
func (c *DbConnection) ExportJson(databasePath string, metadata bool) ([]byte, error) {
|
||||
logrus.WithField("databasePath", databasePath).Infof("exportJson")
|
||||
|
||||
connection, err := bolt.Open(databasePath, 0600, &bolt.Options{Timeout: 1 * time.Second, ReadOnly: true})
|
||||
if err != nil {
|
||||
return []byte("{}"), err
|
||||
}
|
||||
defer logs.CloseAndLogErr(connection)
|
||||
defer connection.Close()
|
||||
|
||||
backup := make(map[string]any)
|
||||
backup := make(map[string]interface{})
|
||||
if metadata {
|
||||
meta, err := backupMetadata(connection)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed exporting metadata")
|
||||
logrus.WithError(err).Errorf("Failed exporting metadata: %v", err)
|
||||
}
|
||||
|
||||
backup["__metadata"] = meta
|
||||
}
|
||||
|
||||
if err := connection.View(func(tx *bolt.Tx) error {
|
||||
return tx.ForEach(func(name []byte, bucket *bolt.Bucket) error {
|
||||
err = connection.View(func(tx *bolt.Tx) error {
|
||||
err = tx.ForEach(func(name []byte, bucket *bolt.Bucket) error {
|
||||
bucketName := string(name)
|
||||
var list []any
|
||||
var list []interface{}
|
||||
version := make(map[string]string)
|
||||
cursor := bucket.Cursor()
|
||||
for k, v := cursor.First(); k != nil; k, v = cursor.Next() {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var obj any
|
||||
var obj interface{}
|
||||
err := c.UnmarshalObject(v, &obj)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("bucket", bucketName).
|
||||
Str("object", string(v)).
|
||||
Err(err).
|
||||
Msg("failed to unmarshal")
|
||||
|
||||
logrus.WithError(err).Errorf("Failed to unmarshal (bucket %s): %v", bucketName, string(v))
|
||||
obj = v
|
||||
}
|
||||
|
||||
if bucketName == "version" {
|
||||
version[string(k)] = string(v)
|
||||
} else {
|
||||
list = append(list, obj)
|
||||
}
|
||||
}
|
||||
|
||||
if bucketName == "version" {
|
||||
backup[bucketName] = version
|
||||
return nil
|
||||
}
|
||||
|
||||
if bucketName == "ssl" ||
|
||||
bucketName == "settings" ||
|
||||
bucketName == "tunnel_server" {
|
||||
backup[bucketName] = nil
|
||||
if len(list) > 0 {
|
||||
backup[bucketName] = list[0]
|
||||
if len(list) > 0 {
|
||||
if bucketName == "ssl" ||
|
||||
bucketName == "settings" ||
|
||||
bucketName == "tunnel_server" {
|
||||
backup[bucketName] = nil
|
||||
if len(list) > 0 {
|
||||
backup[bucketName] = list[0]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
backup[bucketName] = list
|
||||
return nil
|
||||
}
|
||||
|
||||
backup[bucketName] = list
|
||||
|
||||
return nil
|
||||
})
|
||||
}); err != nil {
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return []byte("{}"), err
|
||||
}
|
||||
|
||||
|
||||
@@ -1,42 +1,38 @@
|
||||
package boltdb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
jsoniter "github.com/json-iterator/go"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/segmentio/encoding/json"
|
||||
)
|
||||
|
||||
var errEncryptedStringTooShort = errors.New("encrypted string too short")
|
||||
var errEncryptedStringTooShort = fmt.Errorf("encrypted string too short")
|
||||
|
||||
// MarshalObject encodes an object to binary format
|
||||
func (connection *DbConnection) MarshalObject(object any) ([]byte, error) {
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
func (connection *DbConnection) MarshalObject(object interface{}) (data []byte, err error) {
|
||||
// Special case for the VERSION bucket. Here we're not using json
|
||||
if v, ok := object.(string); ok {
|
||||
buf.WriteString(v)
|
||||
data = []byte(v)
|
||||
} else {
|
||||
enc := json.NewEncoder(buf)
|
||||
enc.SetSortMapKeys(false)
|
||||
enc.SetAppendNewline(false)
|
||||
|
||||
if err := enc.Encode(object); err != nil {
|
||||
return nil, err
|
||||
data, err = json.Marshal(object)
|
||||
if err != nil {
|
||||
return data, err
|
||||
}
|
||||
}
|
||||
|
||||
if connection.getEncryptionKey() == nil {
|
||||
return buf.Bytes(), nil
|
||||
return data, nil
|
||||
}
|
||||
|
||||
return encrypt(buf.Bytes(), connection.getEncryptionKey())
|
||||
return encrypt(data, connection.getEncryptionKey())
|
||||
}
|
||||
|
||||
// UnmarshalObject decodes an object from binary data
|
||||
func (connection *DbConnection) UnmarshalObject(data []byte, object any) error {
|
||||
func (connection *DbConnection) UnmarshalObject(data []byte, object interface{}) error {
|
||||
var err error
|
||||
if connection.getEncryptionKey() != nil {
|
||||
data, err = decrypt(data, connection.getEncryptionKey())
|
||||
@@ -44,60 +40,91 @@ func (connection *DbConnection) UnmarshalObject(data []byte, object any) error {
|
||||
return errors.Wrap(err, "Failed decrypting object")
|
||||
}
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, object); err != nil {
|
||||
e := json.Unmarshal(data, object)
|
||||
if e != nil {
|
||||
// Special case for the VERSION bucket. Here we're not using json
|
||||
// So we need to return it as a string
|
||||
s, ok := object.(*string)
|
||||
if !ok {
|
||||
return errors.Wrap(err, "Failed unmarshalling object")
|
||||
return errors.Wrap(err, e.Error())
|
||||
}
|
||||
|
||||
*s = string(data)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// UnmarshalObjectWithJsoniter decodes an object from binary data
|
||||
// using the jsoniter library. It is mainly used to accelerate environment(endpoint)
|
||||
// decoding at the moment.
|
||||
func (connection *DbConnection) UnmarshalObjectWithJsoniter(data []byte, object interface{}) error {
|
||||
if connection.getEncryptionKey() != nil {
|
||||
var err error
|
||||
data, err = decrypt(data, connection.getEncryptionKey())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
var jsoni = jsoniter.ConfigCompatibleWithStandardLibrary
|
||||
err := jsoni.Unmarshal(data, &object)
|
||||
if err != nil {
|
||||
if s, ok := object.(*string); ok {
|
||||
*s = string(data)
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// mmm, don't have a KMS .... aes GCM seems the most likely from
|
||||
// https://gist.github.com/atoponce/07d8d4c833873be2f68c34f9afc5a78a#symmetric-encryption
|
||||
|
||||
func encrypt(plaintext []byte, passphrase []byte) (encrypted []byte, err error) {
|
||||
block, err := aes.NewCipher(passphrase)
|
||||
block, _ := aes.NewCipher(passphrase)
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return encrypted, err
|
||||
}
|
||||
|
||||
// NewGCMWithRandomNonce in go 1.24 handles setting up the nonce and adding it to the encrypted output
|
||||
gcm, err := cipher.NewGCMWithRandomNonce(block)
|
||||
if err != nil {
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return encrypted, err
|
||||
}
|
||||
|
||||
return gcm.Seal(nil, nil, plaintext, nil), nil
|
||||
ciphertextByte := gcm.Seal(
|
||||
nonce,
|
||||
nonce,
|
||||
plaintext,
|
||||
nil)
|
||||
return ciphertextByte, nil
|
||||
}
|
||||
|
||||
func decrypt(encrypted []byte, passphrase []byte) (plaintextByte []byte, err error) {
if string(encrypted) == "false" {
return []byte("false"), nil
}

block, err := aes.NewCipher(passphrase)
if err != nil {
return encrypted, errors.Wrap(err, "Error creating cypher block")
}

// NewGCMWithRandomNonce in go 1.24 handles reading the nonce from the encrypted input for us
gcm, err := cipher.NewGCMWithRandomNonce(block)
gcm, err := cipher.NewGCM(block)
if err != nil {
return encrypted, errors.Wrap(err, "Error creating GCM")
}

if len(encrypted) < gcm.NonceSize() {
nonceSize := gcm.NonceSize()
if len(encrypted) < nonceSize {
return encrypted, errEncryptedStringTooShort
}

plaintextByte, err = gcm.Open(nil, nil, encrypted, nil)
nonce, ciphertextByteClean := encrypted[:nonceSize], encrypted[nonceSize:]
plaintextByte, err = gcm.Open(
nil,
nonce,
ciphertextByteClean,
nil)
if err != nil {
return encrypted, errors.Wrap(err, "Error decrypting text")
}

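Editorial aside, not part of the diff: a self-contained sketch of the new nonce handling, assuming Go 1.24 or later where crypto/cipher.NewGCMWithRandomNonce is available. The AEAD generates the 12-byte nonce and prepends it to the ciphertext, so the stored format stays nonce, then ciphertext, then tag, and remains readable by the older explicit-nonce code path (which is what Test_NonceSources below verifies). Key size and plaintext are illustrative.

package main

import (
    "crypto/aes"
    "crypto/cipher"
    "crypto/rand"
    "fmt"
)

func main() {
    key := make([]byte, 32) // AES-256 key, e.g. derived from the passphrase
    if _, err := rand.Read(key); err != nil {
        panic(err)
    }

    block, err := aes.NewCipher(key)
    if err != nil {
        panic(err)
    }

    // Seal generates and prepends the nonce; Open strips it again, so both
    // are called with a nil nonce, exactly as in encrypt/decrypt above.
    gcm, err := cipher.NewGCMWithRandomNonce(block)
    if err != nil {
        panic(err)
    }

    sealed := gcm.Seal(nil, nil, []byte("portainer"), nil)
    plain, err := gcm.Open(nil, nil, sealed, nil)
    if err != nil {
        panic(err)
    }

    fmt.Println(string(plain)) // portainer
}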
@@ -1,23 +1,16 @@
|
||||
package boltdb
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
jsonobject = `{"LogoURL":"","BlackListedLabels":[],"AuthenticationMethod":1,"InternalAuthSettings": {"RequiredPasswordLength": 12}"LDAPSettings":{"AnonymousMode":true,"ReaderDN":"","URL":"","TLSConfig":{"TLS":false,"TLSSkipVerify":false},"StartTLS":false,"SearchSettings":[{"BaseDN":"","Filter":"","UserNameAttribute":""}],"GroupSearchSettings":[{"GroupBaseDN":"","GroupFilter":"","GroupAttribute":""}],"AutoCreateUsers":true},"OAuthSettings":{"ClientID":"","AccessTokenURI":"","AuthorizationURI":"","ResourceURI":"","RedirectURI":"","UserIdentifier":"","Scopes":"","OAuthAutoCreateUsers":false,"DefaultTeamID":0,"SSO":true,"LogoutURI":"","KubeSecretKey":"j0zLVtY/lAWBk62ByyF0uP80SOXaitsABP0TTJX8MhI="},"OpenAMTConfiguration":{"Enabled":false,"MPSServer":"","MPSUser":"","MPSPassword":"","MPSToken":"","CertFileContent":"","CertFileName":"","CertFilePassword":"","DomainName":""},"FeatureFlagSettings":{},"SnapshotInterval":"5m","TemplatesURL":"https://raw.githubusercontent.com/portainer/templates/master/templates-2.0.json","EdgeAgentCheckinInterval":5,"EnableEdgeComputeFeatures":false,"UserSessionTimeout":"8h","KubeconfigExpiry":"0","HelmRepositoryURL":"https://charts.bitnami.com/bitnami","KubectlShellImage":"portainer/kubectl-shell","DisplayDonationHeader":false,"DisplayExternalContributors":false,"EnableHostManagementFeatures":false,"AllowVolumeBrowserForRegularUsers":false,"AllowBindMountsForRegularUsers":false,"AllowPrivilegedModeForRegularUsers":false,"AllowHostNamespaceForRegularUsers":false,"AllowStackManagementForRegularUsers":false,"AllowDeviceMappingForRegularUsers":false,"AllowContainerCapabilitiesForRegularUsers":false}`
|
||||
jsonobject = `{"LogoURL":"","BlackListedLabels":[],"AuthenticationMethod":1,"InternalAuthSettings": {"RequiredPasswordLength": 12}"LDAPSettings":{"AnonymousMode":true,"ReaderDN":"","URL":"","TLSConfig":{"TLS":false,"TLSSkipVerify":false},"StartTLS":false,"SearchSettings":[{"BaseDN":"","Filter":"","UserNameAttribute":""}],"GroupSearchSettings":[{"GroupBaseDN":"","GroupFilter":"","GroupAttribute":""}],"AutoCreateUsers":true},"OAuthSettings":{"ClientID":"","AccessTokenURI":"","AuthorizationURI":"","ResourceURI":"","RedirectURI":"","UserIdentifier":"","Scopes":"","OAuthAutoCreateUsers":false,"DefaultTeamID":0,"SSO":true,"LogoutURI":"","KubeSecretKey":"j0zLVtY/lAWBk62ByyF0uP80SOXaitsABP0TTJX8MhI="},"OpenAMTConfiguration":{"Enabled":false,"MPSServer":"","MPSUser":"","MPSPassword":"","MPSToken":"","CertFileContent":"","CertFileName":"","CertFilePassword":"","DomainName":""},"FeatureFlagSettings":{},"SnapshotInterval":"5m","TemplatesURL":"https://raw.githubusercontent.com/portainer/templates/master/templates-2.0.json","EdgeAgentCheckinInterval":5,"EnableEdgeComputeFeatures":false,"UserSessionTimeout":"8h","KubeconfigExpiry":"0","EnableTelemetry":true,"HelmRepositoryURL":"https://charts.bitnami.com/bitnami","KubectlShellImage":"portainer/kubectl-shell","DisplayDonationHeader":false,"DisplayExternalContributors":false,"EnableHostManagementFeatures":false,"AllowVolumeBrowserForRegularUsers":false,"AllowBindMountsForRegularUsers":false,"AllowPrivilegedModeForRegularUsers":false,"AllowHostNamespaceForRegularUsers":false,"AllowStackManagementForRegularUsers":false,"AllowDeviceMappingForRegularUsers":false,"AllowContainerCapabilitiesForRegularUsers":false}`
|
||||
passphrase = "my secret key"
|
||||
)
|
||||
|
||||
@@ -29,10 +22,10 @@ func secretToEncryptionKey(passphrase string) []byte {
|
||||
func Test_MarshalObjectUnencrypted(t *testing.T) {
|
||||
is := assert.New(t)
|
||||
|
||||
uuid := uuid.New()
|
||||
uuid := uuid.Must(uuid.NewV4())
|
||||
|
||||
tests := []struct {
|
||||
object any
|
||||
object interface{}
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
@@ -64,7 +57,7 @@ func Test_MarshalObjectUnencrypted(t *testing.T) {
|
||||
expected: uuid.String(),
|
||||
},
|
||||
{
|
||||
object: map[string]any{"key": "value"},
|
||||
object: map[string]interface{}{"key": "value"},
|
||||
expected: `{"key":"value"}`,
|
||||
},
|
||||
{
|
||||
@@ -80,11 +73,11 @@ func Test_MarshalObjectUnencrypted(t *testing.T) {
|
||||
expected: `["1","2","3"]`,
|
||||
},
|
||||
{
|
||||
object: []map[string]any{{"key1": "value1"}, {"key2": "value2"}},
|
||||
object: []map[string]interface{}{{"key1": "value1"}, {"key2": "value2"}},
|
||||
expected: `[{"key1":"value1"},{"key2":"value2"}]`,
|
||||
},
|
||||
{
|
||||
object: []any{1, "2", false, map[string]any{"key1": "value1"}},
|
||||
object: []interface{}{1, "2", false, map[string]interface{}{"key1": "value1"}},
|
||||
expected: `[1,"2",false,{"key1":"value1"}]`,
|
||||
},
|
||||
}
|
||||
@@ -94,7 +87,7 @@ func Test_MarshalObjectUnencrypted(t *testing.T) {
|
||||
for _, test := range tests {
|
||||
t.Run(fmt.Sprintf("%s -> %s", test.object, test.expected), func(t *testing.T) {
|
||||
data, err := conn.MarshalObject(test.object)
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
is.Equal(test.expected, string(data))
|
||||
})
|
||||
}
|
||||
@@ -135,8 +128,8 @@ func Test_UnMarshalObjectUnencrypted(t *testing.T) {
|
||||
t.Run(fmt.Sprintf("%s -> %s", test.object, test.expected), func(t *testing.T) {
|
||||
var object string
|
||||
err := conn.UnmarshalObject(test.object, &object)
|
||||
require.NoError(t, err)
|
||||
is.Equal(test.expected, object)
|
||||
is.NoError(err)
|
||||
is.Equal(test.expected, string(object))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -167,109 +160,18 @@ func Test_ObjectMarshallingEncrypted(t *testing.T) {
|
||||
}
|
||||
|
||||
key := secretToEncryptionKey(passphrase)
|
||||
conn := DbConnection{EncryptionKey: key, isEncrypted: true}
|
||||
conn := DbConnection{EncryptionKey: key}
|
||||
for _, test := range tests {
|
||||
t.Run(fmt.Sprintf("%s -> %s", test.object, test.expected), func(t *testing.T) {
|
||||
|
||||
data, err := conn.MarshalObject(test.object)
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
var object []byte
|
||||
err = conn.UnmarshalObject(data, &object)
|
||||
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
is.Equal(test.object, object)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_NonceSources(t *testing.T) {
|
||||
// ensure that the new go 1.24 NewGCMWithRandomNonce works correctly with
|
||||
// the old way of creating and including the nonce
|
||||
|
||||
encryptOldFn := func(plaintext []byte, passphrase []byte) (encrypted []byte, err error) {
|
||||
block, _ := aes.NewCipher(passphrase)
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return encrypted, err
|
||||
}
|
||||
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return encrypted, err
|
||||
}
|
||||
|
||||
return gcm.Seal(nonce, nonce, plaintext, nil), nil
|
||||
}
|
||||
|
||||
decryptOldFn := func(encrypted []byte, passphrase []byte) (plaintext []byte, err error) {
|
||||
block, err := aes.NewCipher(passphrase)
|
||||
if err != nil {
|
||||
return encrypted, errors.Wrap(err, "Error creating cypher block")
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return encrypted, errors.Wrap(err, "Error creating GCM")
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(encrypted) < nonceSize {
|
||||
return encrypted, errEncryptedStringTooShort
|
||||
}
|
||||
|
||||
nonce, ciphertextByteClean := encrypted[:nonceSize], encrypted[nonceSize:]
|
||||
|
||||
plaintext, err = gcm.Open(nil, nonce, ciphertextByteClean, nil)
|
||||
if err != nil {
|
||||
return encrypted, errors.Wrap(err, "Error decrypting text")
|
||||
}
|
||||
|
||||
return plaintext, err
|
||||
}
|
||||
|
||||
encryptNewFn := encrypt
|
||||
decryptNewFn := decrypt
|
||||
|
||||
passphrase := make([]byte, 32)
|
||||
_, err := io.ReadFull(rand.Reader, passphrase)
|
||||
require.NoError(t, err)
|
||||
|
||||
junk := make([]byte, 1024)
|
||||
_, err = io.ReadFull(rand.Reader, junk)
|
||||
require.NoError(t, err)
|
||||
|
||||
junkEnc := make([]byte, base64.StdEncoding.EncodedLen(len(junk)))
|
||||
base64.StdEncoding.Encode(junkEnc, junk)
|
||||
|
||||
cases := [][]byte{
|
||||
[]byte("test"),
|
||||
[]byte("35"),
|
||||
[]byte("9ca4a1dd-a439-4593-b386-a7dfdc2e9fc6"),
|
||||
[]byte(jsonobject),
|
||||
passphrase,
|
||||
junk,
|
||||
junkEnc,
|
||||
}
|
||||
|
||||
for _, plain := range cases {
|
||||
var enc, dec []byte
|
||||
var err error
|
||||
|
||||
enc, err = encryptOldFn(plain, passphrase)
|
||||
require.NoError(t, err)
|
||||
|
||||
dec, err = decryptNewFn(enc, passphrase)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, plain, dec)
|
||||
|
||||
enc, err = encryptNewFn(plain, passphrase)
|
||||
require.NoError(t, err)
|
||||
|
||||
dec, err = decryptOldFn(enc, passphrase)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, plain, dec)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,179 +0,0 @@
|
||||
package boltdb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
dserrors "github.com/portainer/portainer/api/dataservices/errors"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog/log"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
type DbTransaction struct {
|
||||
conn *DbConnection
|
||||
tx *bolt.Tx
|
||||
}
|
||||
|
||||
func (tx *DbTransaction) SetServiceName(bucketName string) error {
|
||||
_, err := tx.tx.CreateBucketIfNotExists([]byte(bucketName))
|
||||
return err
|
||||
}
|
||||
|
||||
func (tx *DbTransaction) GetObject(bucketName string, key []byte, object any) error {
|
||||
bucket := tx.tx.Bucket([]byte(bucketName))
|
||||
|
||||
value := bucket.Get(key)
|
||||
if value == nil {
|
||||
return fmt.Errorf("%w (bucket=%s, key=%s)", dserrors.ErrObjectNotFound, bucketName, keyToString(key))
|
||||
}
|
||||
|
||||
return tx.conn.UnmarshalObject(value, object)
|
||||
}
|
||||
|
||||
func (tx *DbTransaction) GetRawBytes(bucketName string, key []byte) ([]byte, error) {
|
||||
bucket := tx.tx.Bucket([]byte(bucketName))
|
||||
|
||||
value := bucket.Get(key)
|
||||
if value == nil {
|
||||
return nil, fmt.Errorf("%w (bucket=%s, key=%s)", dserrors.ErrObjectNotFound, bucketName, keyToString(key))
|
||||
}
|
||||
|
||||
if tx.conn.getEncryptionKey() != nil {
|
||||
var err error
|
||||
|
||||
if value, err = decrypt(value, tx.conn.getEncryptionKey()); err != nil {
|
||||
return value, errors.Wrap(err, "Failed decrypting object")
|
||||
}
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (tx *DbTransaction) KeyExists(bucketName string, key []byte) (bool, error) {
|
||||
bucket := tx.tx.Bucket([]byte(bucketName))
|
||||
|
||||
value := bucket.Get(key)
|
||||
|
||||
return value != nil, nil
|
||||
}
|
||||
|
||||
func (tx *DbTransaction) UpdateObject(bucketName string, key []byte, object any) error {
|
||||
data, err := tx.conn.MarshalObject(object)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bucket := tx.tx.Bucket([]byte(bucketName))
|
||||
return bucket.Put(key, data)
|
||||
}
|
||||
|
||||
func (tx *DbTransaction) DeleteObject(bucketName string, key []byte) error {
|
||||
bucket := tx.tx.Bucket([]byte(bucketName))
|
||||
return bucket.Delete(key)
|
||||
}
|
||||
|
||||
func (tx *DbTransaction) DeleteAllObjects(bucketName string, obj any, matchingFn func(o any) (id int, ok bool)) error {
|
||||
var ids []int
|
||||
|
||||
bucket := tx.tx.Bucket([]byte(bucketName))
|
||||
|
||||
cursor := bucket.Cursor()
|
||||
for k, v := cursor.First(); k != nil; k, v = cursor.Next() {
|
||||
err := tx.conn.UnmarshalObject(v, &obj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if id, ok := matchingFn(obj); ok {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
|
||||
for _, id := range ids {
|
||||
if err := bucket.Delete(tx.conn.ConvertToKey(id)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *DbTransaction) GetNextIdentifier(bucketName string) int {
|
||||
bucket := tx.tx.Bucket([]byte(bucketName))
|
||||
|
||||
id, err := bucket.NextSequence()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("bucket", bucketName).Msg("failed to get the next identifier")
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
return int(id)
|
||||
}
|
||||
|
||||
func (tx *DbTransaction) CreateObject(bucketName string, fn func(uint64) (int, any)) error {
|
||||
bucket := tx.tx.Bucket([]byte(bucketName))
|
||||
|
||||
seqId, _ := bucket.NextSequence()
|
||||
id, obj := fn(seqId)
|
||||
|
||||
data, err := tx.conn.MarshalObject(obj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bucket.Put(tx.conn.ConvertToKey(id), data)
|
||||
}
|
||||
|
||||
func (tx *DbTransaction) CreateObjectWithId(bucketName string, id int, obj any) error {
|
||||
bucket := tx.tx.Bucket([]byte(bucketName))
|
||||
data, err := tx.conn.MarshalObject(obj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bucket.Put(tx.conn.ConvertToKey(id), data)
|
||||
}
|
||||
|
||||
func (tx *DbTransaction) CreateObjectWithStringId(bucketName string, id []byte, obj any) error {
|
||||
bucket := tx.tx.Bucket([]byte(bucketName))
|
||||
data, err := tx.conn.MarshalObject(obj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bucket.Put(id, data)
|
||||
}
|
||||
|
||||
func (tx *DbTransaction) GetAll(bucketName string, obj any, appendFn func(o any) (any, error)) error {
|
||||
bucket := tx.tx.Bucket([]byte(bucketName))
|
||||
|
||||
return bucket.ForEach(func(k []byte, v []byte) error {
|
||||
err := tx.conn.UnmarshalObject(v, obj)
|
||||
if err == nil {
|
||||
obj, err = appendFn(obj)
|
||||
}
|
||||
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (tx *DbTransaction) GetAllWithKeyPrefix(bucketName string, keyPrefix []byte, obj any, appendFn func(o any) (any, error)) error {
|
||||
cursor := tx.tx.Bucket([]byte(bucketName)).Cursor()
|
||||
|
||||
for k, v := cursor.Seek(keyPrefix); k != nil && bytes.HasPrefix(k, keyPrefix); k, v = cursor.Next() {
|
||||
err := tx.conn.UnmarshalObject(v, obj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
obj, err = appendFn(obj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,104 +0,0 @@
|
||||
package boltdb
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const testBucketName = "test-bucket"
|
||||
const testId = 1234
|
||||
|
||||
type testStruct struct {
|
||||
Key string
|
||||
Value string
|
||||
}
|
||||
|
||||
func TestTxs(t *testing.T) {
|
||||
conn := DbConnection{Path: t.TempDir()}
|
||||
|
||||
err := conn.Open()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := conn.Close()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Error propagation
|
||||
err = conn.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return errors.New("this is an error")
|
||||
})
|
||||
require.Error(t, err)
|
||||
|
||||
// Create an object
|
||||
newObj := testStruct{Key: "key", Value: "value"}
|
||||
|
||||
err = conn.UpdateTx(func(tx portainer.Transaction) error {
|
||||
if err := tx.SetServiceName(testBucketName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.CreateObjectWithId(testBucketName, testId, newObj)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
obj := testStruct{}
|
||||
err = conn.ViewTx(func(tx portainer.Transaction) error {
|
||||
return tx.GetObject(testBucketName, conn.ConvertToKey(testId), &obj)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
if obj.Key != newObj.Key || obj.Value != newObj.Value {
|
||||
t.Fatalf("expected %s:%s, got %s:%s instead", newObj.Key, newObj.Value, obj.Key, obj.Value)
|
||||
}
|
||||
|
||||
// Update an object
|
||||
updatedObj := testStruct{Key: "updated-key", Value: "updated-value"}
|
||||
|
||||
err = conn.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return tx.UpdateObject(testBucketName, conn.ConvertToKey(testId), &updatedObj)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.ViewTx(func(tx portainer.Transaction) error {
|
||||
return tx.GetObject(testBucketName, conn.ConvertToKey(testId), &obj)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
if obj.Key != updatedObj.Key || obj.Value != updatedObj.Value {
|
||||
t.Fatalf("expected %s:%s, got %s:%s instead", updatedObj.Key, updatedObj.Value, obj.Key, obj.Value)
|
||||
}
|
||||
|
||||
// Delete an object
|
||||
err = conn.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return tx.DeleteObject(testBucketName, conn.ConvertToKey(testId))
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.ViewTx(func(tx portainer.Transaction) error {
|
||||
return tx.GetObject(testBucketName, conn.ConvertToKey(testId), &obj)
|
||||
})
|
||||
require.True(t, dataservices.IsErrObjectNotFound(err))
|
||||
|
||||
// Get next identifier
|
||||
err = conn.UpdateTx(func(tx portainer.Transaction) error {
|
||||
id1 := tx.GetNextIdentifier(testBucketName)
|
||||
id2 := tx.GetNextIdentifier(testBucketName)
|
||||
|
||||
if id1+1 != id2 {
|
||||
return errors.New("unexpected identifier sequence")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to write in a read transaction
|
||||
err = conn.ViewTx(func(tx portainer.Transaction) error {
|
||||
return tx.CreateObjectWithId(testBucketName, testId, newObj)
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
@@ -8,14 +8,13 @@ import (
)

// NewDatabase should use config options to return a connection to the requested database
func NewDatabase(storeType, storePath string, encryptionKey []byte, compact bool) (connection portainer.Connection, err error) {
if storeType == "boltdb" {
func NewDatabase(storeType, storePath string, encryptionKey []byte) (connection portainer.Connection, err error) {
switch storeType {
case "boltdb":
return &boltdb.DbConnection{
Path: storePath,
EncryptionKey: encryptionKey,
Compact: compact,
}, nil
}

return nil, fmt.Errorf("Unknown storage database: %s", storeType)
return nil, fmt.Errorf("unknown storage database: %s", storeType)
}

@@ -1,24 +0,0 @@
package database

import (
"testing"

"github.com/portainer/portainer/api/database/boltdb"
"github.com/portainer/portainer/api/filesystem"

"github.com/stretchr/testify/require"
)

func TestNewDatabase(t *testing.T) {
dbPath := filesystem.JoinPaths(t.TempDir(), "test.db")
connection, err := NewDatabase("boltdb", dbPath, nil, false)
require.NoError(t, err)
require.NotNil(t, connection)

_, ok := connection.(*boltdb.DbConnection)
require.True(t, ok)

connection, err = NewDatabase("unknown", dbPath, nil, false)
require.Error(t, err)
require.Nil(t, connection)
}

@@ -1,8 +0,0 @@
package models

type Version struct {
SchemaVersion string
MigratorCount int
Edition int
InstanceID string
}
@@ -1,56 +1,52 @@
|
||||
package apikeyrepository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
dserrors "github.com/portainer/portainer/api/dataservices/errors"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/portainer/portainer/api/dataservices/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// BucketName represents the name of the bucket where this service stores data.
|
||||
const BucketName = "api_key"
|
||||
const (
|
||||
// BucketName represents the name of the bucket where this service stores data.
|
||||
BucketName = "api_key"
|
||||
)
|
||||
|
||||
// Service represents a service for managing api-key data.
|
||||
type Service struct {
|
||||
dataservices.BaseDataService[portainer.APIKey, portainer.APIKeyID]
|
||||
connection portainer.Connection
|
||||
}
|
||||
|
||||
// NewService creates a new instance of a service.
|
||||
func NewService(connection portainer.Connection) (*Service, error) {
|
||||
if err := connection.SetServiceName(BucketName); err != nil {
|
||||
err := connection.SetServiceName(BucketName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Service{
|
||||
BaseDataService: dataservices.BaseDataService[portainer.APIKey, portainer.APIKeyID]{
|
||||
Bucket: BucketName,
|
||||
Connection: connection,
|
||||
},
|
||||
connection: connection,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetAPIKeysByUserID returns a slice containing all the APIKeys a user has access to.
|
||||
func (service *Service) GetAPIKeysByUserID(userID portainer.UserID) ([]portainer.APIKey, error) {
|
||||
result := make([]portainer.APIKey, 0)
|
||||
var result = make([]portainer.APIKey, 0)
|
||||
|
||||
err := service.Connection.GetAll(
|
||||
err := service.connection.GetAll(
|
||||
BucketName,
|
||||
&portainer.APIKey{},
|
||||
func(obj any) (any, error) {
|
||||
func(obj interface{}) (interface{}, error) {
|
||||
record, ok := obj.(*portainer.APIKey)
|
||||
if !ok {
|
||||
log.Debug().Str("obj", fmt.Sprintf("%#v", obj)).Msg("failed to convert to APIKey object")
|
||||
return nil, fmt.Errorf("failed to convert to APIKey object: %s", obj)
|
||||
logrus.WithField("obj", obj).Errorf("Failed to convert to APIKey object")
|
||||
return nil, fmt.Errorf("Failed to convert to APIKey object: %s", obj)
|
||||
}
|
||||
|
||||
if record.UserID == userID {
|
||||
result = append(result, *record)
|
||||
}
|
||||
|
||||
return &portainer.APIKey{}, nil
|
||||
})
|
||||
|
||||
@@ -59,45 +55,65 @@ func (service *Service) GetAPIKeysByUserID(userID portainer.UserID) ([]portainer
|
||||
|
||||
// GetAPIKeyByDigest returns the API key for the associated digest.
|
||||
// Note: there is a 1-to-1 mapping of api-key and digest
|
||||
func (service *Service) GetAPIKeyByDigest(digest string) (*portainer.APIKey, error) {
|
||||
func (service *Service) GetAPIKeyByDigest(digest []byte) (*portainer.APIKey, error) {
|
||||
var k *portainer.APIKey
|
||||
stop := errors.New("ok")
|
||||
err := service.Connection.GetAll(
|
||||
stop := fmt.Errorf("ok")
|
||||
err := service.connection.GetAll(
|
||||
BucketName,
|
||||
&portainer.APIKey{},
|
||||
func(obj any) (any, error) {
|
||||
func(obj interface{}) (interface{}, error) {
|
||||
key, ok := obj.(*portainer.APIKey)
|
||||
if !ok {
|
||||
log.Debug().Str("obj", fmt.Sprintf("%#v", obj)).Msg("failed to convert to APIKey object")
|
||||
return nil, fmt.Errorf("failed to convert to APIKey object: %s", obj)
|
||||
logrus.WithField("obj", obj).Errorf("Failed to convert to APIKey object")
|
||||
return nil, fmt.Errorf("Failed to convert to APIKey object: %s", obj)
|
||||
}
|
||||
if key.Digest == digest {
|
||||
if bytes.Equal(key.Digest, digest) {
|
||||
k = key
|
||||
return nil, stop
|
||||
}
|
||||
|
||||
return &portainer.APIKey{}, nil
|
||||
})
|
||||
|
||||
if errors.Is(err, stop) {
|
||||
if err == stop {
|
||||
return k, nil
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
return nil, dserrors.ErrObjectNotFound
|
||||
return nil, errors.ErrObjectNotFound
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create creates a new APIKey object.
|
||||
func (service *Service) Create(record *portainer.APIKey) error {
|
||||
return service.Connection.CreateObject(
|
||||
// CreateAPIKey creates a new APIKey object.
|
||||
func (service *Service) CreateAPIKey(record *portainer.APIKey) error {
|
||||
return service.connection.CreateObject(
|
||||
BucketName,
|
||||
func(id uint64) (int, any) {
|
||||
func(id uint64) (int, interface{}) {
|
||||
record.ID = portainer.APIKeyID(id)
|
||||
|
||||
return int(record.ID), record
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// GetAPIKey retrieves an existing APIKey object by api key ID.
|
||||
func (service *Service) GetAPIKey(keyID portainer.APIKeyID) (*portainer.APIKey, error) {
|
||||
var key portainer.APIKey
|
||||
identifier := service.connection.ConvertToKey(int(keyID))
|
||||
|
||||
err := service.connection.GetObject(BucketName, identifier, &key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &key, nil
|
||||
}
|
||||
|
||||
func (service *Service) UpdateAPIKey(key *portainer.APIKey) error {
|
||||
identifier := service.connection.ConvertToKey(int(key.ID))
|
||||
return service.connection.UpdateObject(BucketName, identifier, key)
|
||||
}
|
||||
|
||||
func (service *Service) DeleteAPIKey(ID portainer.APIKeyID) error {
|
||||
identifier := service.connection.ConvertToKey(int(ID))
|
||||
return service.connection.DeleteObject(BucketName, identifier)
|
||||
}
|
||||
|
||||
@@ -1,81 +0,0 @@
package dataservices

import (
    portainer "github.com/portainer/portainer/api"

    "golang.org/x/exp/constraints"
)

type BaseCRUD[T any, I constraints.Integer] interface {
    Create(element *T) error
    Read(ID I) (*T, error)
    Exists(ID I) (bool, error)
    ReadAll(predicates ...func(T) bool) ([]T, error)
    Update(ID I, element *T) error
    Delete(ID I) error
}

type BaseDataService[T any, I constraints.Integer] struct {
    Bucket     string
    Connection portainer.Connection
}

func (s *BaseDataService[T, I]) BucketName() string {
    return s.Bucket
}

func (service *BaseDataService[T, I]) Tx(tx portainer.Transaction) BaseDataServiceTx[T, I] {
    return BaseDataServiceTx[T, I]{
        Bucket:     service.Bucket,
        Connection: service.Connection,
        Tx:         tx,
    }
}

func (service BaseDataService[T, I]) Read(ID I) (*T, error) {
    var element *T

    return element, service.Connection.ViewTx(func(tx portainer.Transaction) error {
        var err error
        element, err = service.Tx(tx).Read(ID)

        return err
    })
}

func (service BaseDataService[T, I]) Exists(ID I) (bool, error) {
    var exists bool

    err := service.Connection.ViewTx(func(tx portainer.Transaction) error {
        var err error
        exists, err = service.Tx(tx).Exists(ID)

        return err
    })

    return exists, err
}

// ReadAll retrieves all the elements that satisfy all the provided predicates.
func (service BaseDataService[T, I]) ReadAll(predicates ...func(T) bool) ([]T, error) {
    var collection = make([]T, 0)

    return collection, service.Connection.ViewTx(func(tx portainer.Transaction) error {
        var err error
        collection, err = service.Tx(tx).ReadAll(predicates...)

        return err
    })
}

func (service BaseDataService[T, I]) Update(ID I, element *T) error {
    return service.Connection.UpdateTx(func(tx portainer.Transaction) error {
        return service.Tx(tx).Update(ID, element)
    })
}

func (service BaseDataService[T, I]) Delete(ID I) error {
    return service.Connection.UpdateTx(func(tx portainer.Transaction) error {
        return service.Tx(tx).Delete(ID)
    })
}
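For context (not part of the diff above): a small, self-contained sketch of the same pattern the removed BaseCRUD interface describes, namely a generic store keyed by an integer-like ID whose ReadAll applies every supplied predicate. The memStore and team types below are hypothetical illustrations and not Portainer code.

// illustrative sketch only; not Portainer code
package main

import "fmt"

type memStore[T any, I comparable] struct {
    items map[I]T
}

func (s *memStore[T, I]) Update(id I, element *T) error {
    s.items[id] = *element
    return nil
}

func (s *memStore[T, I]) ReadAll(predicates ...func(T) bool) ([]T, error) {
    out := make([]T, 0, len(s.items))
next:
    for _, v := range s.items {
        for _, p := range predicates {
            if !p(v) {
                continue next // an element must satisfy every predicate to be kept
            }
        }
        out = append(out, v)
    }
    return out, nil
}

type team struct{ ID, Size int }

func main() {
    s := &memStore[team, int]{items: map[int]team{}}
    _ = s.Update(1, &team{ID: 1, Size: 3})
    _ = s.Update(2, &team{ID: 2, Size: 8})

    big, _ := s.ReadAll(func(t team) bool { return t.Size > 5 })
    fmt.Println(big) // only teams with more than 5 members
}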
@@ -1,91 +0,0 @@
package dataservices

import (
    "strconv"
    "testing"

    portainer "github.com/portainer/portainer/api"
    "github.com/portainer/portainer/api/slicesx"

    "github.com/stretchr/testify/require"
)

type testObject struct {
    ID    int
    Value int
}

type mockConnection struct {
    store map[int]testObject

    portainer.Connection
}

func (m mockConnection) UpdateObject(bucket string, key []byte, value any) error {
    obj := value.(*testObject)

    m.store[obj.ID] = *obj

    return nil
}

func (m mockConnection) GetAll(bucketName string, obj any, appendFn func(o any) (any, error)) error {
    for _, v := range m.store {
        if _, err := appendFn(&v); err != nil {
            return err
        }
    }

    return nil
}

func (m mockConnection) UpdateTx(fn func(portainer.Transaction) error) error {
    return fn(m)
}

func (m mockConnection) ViewTx(fn func(portainer.Transaction) error) error {
    return fn(m)
}

func (m mockConnection) ConvertToKey(v int) []byte {
    return []byte(strconv.Itoa(v))
}

func TestReadAll(t *testing.T) {
    service := BaseDataService[testObject, int]{
        Bucket:     "testBucket",
        Connection: mockConnection{store: make(map[int]testObject)},
    }

    data := []testObject{
        {ID: 1, Value: 1},
        {ID: 2, Value: 2},
        {ID: 3, Value: 3},
        {ID: 4, Value: 4},
        {ID: 5, Value: 5},
    }

    for _, item := range data {
        err := service.Update(item.ID, &item)
        require.NoError(t, err)
    }

    // ReadAll without predicates
    result, err := service.ReadAll()
    require.NoError(t, err)

    expected := append([]testObject{}, data...)

    require.ElementsMatch(t, expected, result)

    // ReadAll with predicates
    hasLowID := func(obj testObject) bool { return obj.ID < 3 }
    isEven := func(obj testObject) bool { return obj.Value%2 == 0 }

    result, err = service.ReadAll(hasLowID, isEven)
    require.NoError(t, err)

    expected = slicesx.Filter(expected, hasLowID)
    expected = slicesx.Filter(expected, isEven)

    require.ElementsMatch(t, expected, result)
}
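For context (not part of the diff above): the test compares the service result against slicesx.Filter applied to the same predicates. A hedged, self-contained stand-in for the behaviour the test assumes is shown below; it keeps only elements for which the predicate returns true and is not the actual slicesx implementation.

// illustrative sketch only; not Portainer code
package main

import "fmt"

func Filter[T any](in []T, keep func(T) bool) []T {
    out := make([]T, 0, len(in))
    for _, v := range in {
        if keep(v) {
            out = append(out, v)
        }
    }
    return out
}

func main() {
    nums := []int{1, 2, 3, 4, 5}
    fmt.Println(Filter(nums, func(n int) bool { return n%2 == 0 })) // [2 4]
}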
Some files were not shown because too many files have changed in this diff.