Merge branch 'main' into dry

This commit is contained in:
schulze 2023-05-02 16:10:03 +02:00
commit 5505a33557
7 changed files with 515 additions and 45 deletions

2
.gitignore vendored
View File

@ -20,3 +20,5 @@
config.yaml
polemos

View File

@ -1,5 +1,6 @@
mtd:
services: []
services: {}
management_port: 14000
aws:
regions: []
credentials_path: ./mtdaws/.credentials

105
main.go
View File

@ -3,9 +3,11 @@ package main
import (
"fmt"
"net/netip"
"time"
"github.com/google/uuid"
"github.com/thefeli73/polemos/mtdaws"
"github.com/thefeli73/polemos/pcsdk"
"github.com/thefeli73/polemos/state"
)
@ -17,45 +19,120 @@ func main() {
ConfigPath = "config.yaml"
config := state.LoadConf(ConfigPath)
// Initialize the config.Services map
var config state.Config
config.MTD.Services = make(map[state.CustomUUID]state.Service)
config = state.LoadConf(ConfigPath)
state.SaveConf(ConfigPath, config)
config = indexInstances(config)
config = indexAllInstances(config)
state.SaveConf(ConfigPath, config)
// CREATE TUNNELS
createTunnels(config)
// START DOING MTD
mtdLoop(config)
}
func indexInstances(config state.Config) state.Config {
func mtdLoop(config state.Config) {
for true {
//TODO: figure out migration (MTD)
config = movingTargetDefense(config)
state.SaveConf(ConfigPath, config)
fmt.Println("Sleeping for 1 minute")
time.Sleep(1*time.Minute)
//TODO: proxy commands
}
}
func movingTargetDefense(config state.Config) state.Config{
mtdaws.AWSMoveInstance(config)
return config
}
func indexAllInstances(config state.Config) state.Config {
fmt.Println("Indexing instances")
t := time.Now()
for _, service := range config.MTD.Services {
service.Active = false
}
//index AWS instances
awsNewInstanceCounter := 0
awsInactiveInstanceCounter := len(config.MTD.Services)
awsInstanceCounter := 0
awsInstances := mtdaws.GetInstances(config)
for _, instance := range awsInstances {
cloudID := mtdaws.GetCloudID(instance)
ip, err := netip.ParseAddr(instance.PublicIP)
if err != nil {
fmt.Println("Error converting ip:", err)
fmt.Println("Error converting ip:\t", err)
continue
}
newService, found := indexInstance(config, cloudID, ip)
var found bool
config, found = indexInstance(config, cloudID, ip)
if !found {
config.MTD.Services = append(config.MTD.Services, newService)
state.SaveConf(ConfigPath, config)
awsNewInstanceCounter++
} else {
awsInactiveInstanceCounter--
}
awsInstanceCounter++
}
// TODO: Purge instances in config that are not found in the cloud
fmt.Printf("Found %d active AWS instances (%d newly added, %d inactive) (took %s)\n",
awsInstanceCounter, awsNewInstanceCounter, awsInactiveInstanceCounter, time.Since(t).Round(100*time.Millisecond).String())
return config
}
func indexInstance(config state.Config, cloudID string, serviceIP netip.Addr) (state.Service, bool) {
func createTunnels(config state.Config) {
for serviceUUID, service := range config.MTD.Services {
if service.AdminEnabled && service.Active {
s := pcsdk.NewCommandStatus()
err := s.Execute(netip.AddrPortFrom(service.EntryIP, config.MTD.ManagementPort))
if err != nil {
continue
}
// Reconfigure Proxy to new instance
c := pcsdk.NewCommandCreate(service.EntryPort, service.ServicePort, service.ServiceIP, serviceUUID)
err = c.Execute(netip.AddrPortFrom(service.EntryIP, config.MTD.ManagementPort))
if err != nil {
continue
}
}
}
}
func indexInstance(config state.Config, cloudID string, serviceIP netip.Addr) (state.Config, bool) {
found := false
for _, service := range config.MTD.Services {
var foundUUID state.CustomUUID
for u, service := range config.MTD.Services {
if service.CloudID == cloudID {
found = true
foundUUID = u
break;
}
}
if !found {
fmt.Println("New instance found:\t", cloudID)
u := uuid.New()
newService := state.Service{
ID: state.CustomUUID(u),
CloudID: cloudID,
ServiceIP: serviceIP}
return newService, found
config.MTD.Services[state.CustomUUID(u)] = state.Service{CloudID: cloudID, ServiceIP: serviceIP, Active: true, AdminEnabled: true}
state.SaveConf(ConfigPath, config)
} else {
s := config.MTD.Services[foundUUID]
s.Active = true
config.MTD.Services[foundUUID] = s
state.SaveConf(ConfigPath, config)
}
return config, found
}

181
mtdaws/mtd.go Normal file
View File

@ -0,0 +1,181 @@
package mtdaws
import (
"fmt"
"net/netip"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/google/uuid"
"github.com/thefeli73/polemos/pcsdk"
"github.com/thefeli73/polemos/state"
)
// AWSMoveInstance moves a specified instance to a new availability region
func AWSMoveInstance(config state.Config) (state.Config) {
// pseudorandom instance from all services for testing
var serviceUUID state.CustomUUID
var instance state.Service
for key, service := range config.MTD.Services {
serviceUUID = key
instance = service
if !instance.AdminEnabled {continue}
if !instance.Active {continue}
break
}
fmt.Println("MTD move service:\t", uuid.UUID.String(uuid.UUID(serviceUUID)))
// Test Proxy Connection
t := time.Now()
s := pcsdk.NewCommandStatus()
err := s.Execute(netip.AddrPortFrom(instance.EntryIP, config.MTD.ManagementPort))
if err != nil {
fmt.Printf("error executing test command: %s\n", err)
return config
}
fmt.Printf("Proxy Tested. (took %s)\n", time.Since(t).Round(100*time.Millisecond).String())
region, instanceID := DecodeCloudID(instance.CloudID)
awsConfig := NewConfig(region, config.AWS.CredentialsPath)
svc := ec2.NewFromConfig(awsConfig)
realInstance, err := getInstanceDetailsFromString(svc, instanceID)
if err != nil {
fmt.Println("Error getting instance details:\t", err)
return config
}
if !isInstanceRunning(realInstance) {
fmt.Println("Error, Instance is not running!")
return config
}
//Create image
t = time.Now()
imageName, err := createImage(svc, instanceID)
if err != nil {
fmt.Println("Error creating image:\t", err)
return config
}
fmt.Printf("Created image:\t\t%s (took %s)\n", imageName, time.Since(t).Round(100*time.Millisecond).String())
// Wait for image
t = time.Now()
err = waitForImageReady(svc, imageName, 5*time.Minute)
if err != nil {
fmt.Println("Error waiting for image to be ready:\t", err)
return config
}
fmt.Printf("Image is ready:\t\t%s (took %s)\n", imageName, time.Since(t).Round(100*time.Millisecond).String())
// Launch new instance
t = time.Now()
newInstanceID, err := launchInstance(svc, realInstance, imageName, region)
if err != nil {
fmt.Println("Error launching instance:\t", err)
return config
}
fmt.Printf("Launched new instance:\t%s (took %s)\n", newInstanceID, time.Since(t).Round(100*time.Millisecond).String())
// Wait for instance
t = time.Now()
err = waitForInstanceReady(svc, newInstanceID, 5*time.Minute)
if err != nil {
fmt.Println("Error waiting for instance to be ready:\t", err)
return config
}
fmt.Printf("instance is ready:\t\t%s (took %s)\n", newInstanceID, time.Since(t).Round(100*time.Millisecond).String())
// update local config to match new instance
config = AWSUpdateService(config, region, serviceUUID, newInstanceID)
// Reconfigure Proxy to new instance
t = time.Now()
m := pcsdk.NewCommandModify(config.MTD.Services[serviceUUID].ServicePort, config.MTD.Services[serviceUUID].ServiceIP, serviceUUID)
err = m.Execute(netip.AddrPortFrom(config.MTD.Services[serviceUUID].EntryIP, config.MTD.ManagementPort))
if err != nil {
fmt.Printf("error executing modify command: %s\n", err)
return config
}
fmt.Printf("Proxy modified. (took %s)\n", time.Since(t).Round(100*time.Millisecond).String())
// take care of old instance, deregister image and delete snapshot
cleanupAWS(svc, config, instanceID, imageName)
return config
}
// AWSUpdateService updates a specified service config to match a newly moved instance
func AWSUpdateService(config state.Config, region string, service state.CustomUUID, newInstanceID string) (state.Config) {
awsConfig := NewConfig(region, config.AWS.CredentialsPath)
svc := ec2.NewFromConfig(awsConfig)
instance, err := getInstanceDetailsFromString(svc, newInstanceID)
if err != nil {
fmt.Println("Error getting instance details:\t", err)
return config
}
var publicAddr string
if instance.PublicIpAddress != nil {
publicAddr = aws.ToString(instance.PublicIpAddress)
}
formattedinstance := AwsInstance{
InstanceID: aws.ToString(instance.InstanceId),
Region: region,
PublicIP: publicAddr,
PrivateIP: aws.ToString(instance.PrivateIpAddress),
}
cloudid := GetCloudID(formattedinstance)
serviceip := netip.MustParseAddr(publicAddr)
s := config.MTD.Services[service]
s.CloudID = cloudid
s.ServiceIP = serviceip
config.MTD.Services[service] = s
return config
}
// isInstanceRunning returns if an instance is running (true=running)
func isInstanceRunning(instance *types.Instance) bool {
return instance.State.Name == types.InstanceStateNameRunning
}
// cleanupAWS terminates the old instance, deregisters the image and deletes the old snapshot
func cleanupAWS(svc *ec2.Client, config state.Config, instanceID string, imageName string) state.Config {
// Terminate old instance
t := time.Now()
err := terminateInstance(svc, instanceID)
if err != nil {
fmt.Println("Error terminating instance:\t", err)
return config
}
fmt.Printf("Killed old instance:\t%s (took %s)\n", instanceID, time.Since(t).Round(100*time.Millisecond).String())
// Deregister old image
t = time.Now()
image, err := describeImage(svc, imageName)
if err != nil {
fmt.Println("Error describing image:\t", err)
return config
}
err = deregisterImage(svc, imageName)
if err != nil {
fmt.Println("Error deregistering image:\t", err)
return config
}
fmt.Printf("Deregistered image:\t%s (took %s)\n", imageName, time.Since(t).Round(100*time.Millisecond).String())
// Delete old snapshot
t = time.Now()
if len(image.BlockDeviceMappings) > 0 {
snapshotID := aws.ToString(image.BlockDeviceMappings[0].Ebs.SnapshotId)
err = deleteSnapshot(svc, snapshotID)
if err != nil {
fmt.Println("Error deleting snapshot:\t", err)
return config
}
fmt.Printf("Deleted snapshot:\t%s (took %s)\n", snapshotID, time.Since(t).Round(100*time.Millisecond).String())
}
return config
}

View File

@ -2,8 +2,12 @@ package mtdaws
import (
"context"
"errors"
"fmt"
"math/rand"
"os"
"strings"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
@ -36,6 +40,17 @@ func GetCloudID(instance AwsInstance) string {
return "aws_" + instance.Region + "_" + instance.InstanceID
}
// DecodeCloudID returns information to locate instance in aws
func DecodeCloudID(cloudID string) (string, string) {
split := strings.Split(cloudID, "_")
if len(split) != 3 {
panic(cloudID + " does not decode as AWS CloudID")
}
region := split[1]
instanceID := split[2]
return region, instanceID
}
// GetInstances scans all configured regions for instances and add them to services
func GetInstances(config state.Config) []AwsInstance {
awsInstances := []AwsInstance{}
@ -46,8 +61,6 @@ func GetInstances(config state.Config) []AwsInstance {
fmt.Println("Error listing instances:", err)
continue
}
//fmt.Println("Listing instances in region:", region)
for _, instance := range instances {
var publicAddr string
if instance.PublicIpAddress != nil {
@ -63,25 +76,12 @@ func GetInstances(config state.Config) []AwsInstance {
return awsInstances
}
// PrintInstanceInfo prints info about a specific instance in a region
func PrintInstanceInfo(instance *types.Instance) {
fmt.Println("\tInstance ID:", aws.ToString(instance.InstanceId))
fmt.Println("\t\tInstance Type:", string(instance.InstanceType))
fmt.Println("\t\tAMI ID:", aws.ToString(instance.ImageId))
fmt.Println("\t\tState:", string(instance.State.Name))
fmt.Println("\t\tAvailability Zone:", aws.ToString(instance.Placement.AvailabilityZone))
if instance.PublicIpAddress != nil {
fmt.Println("\t\tPublic IP Address:", aws.ToString(instance.PublicIpAddress))
}
fmt.Println("\t\tPrivate IP Address:", aws.ToString(instance.PrivateIpAddress))
}
// Instances returns all instances for a config i.e. a region
func Instances(config aws.Config) ([]*types.Instance, error) {
func Instances(config aws.Config) ([]types.Instance, error) {
svc := ec2.NewFromConfig(config)
input := &ec2.DescribeInstancesInput{}
var instances []*types.Instance
var instances []types.Instance
paginator := ec2.NewDescribeInstancesPaginator(svc, input)
@ -93,10 +93,220 @@ func Instances(config aws.Config) ([]*types.Instance, error) {
for _, reservation := range page.Reservations {
for _, instance := range reservation.Instances {
instances = append(instances, &instance)
instances = append(instances, instance)
}
}
}
return instances, nil
}
// createImage will create an AMI (amazon machine image) of a given instance
func createImage(svc *ec2.Client, instanceID string) (string, error) {
input := &ec2.CreateImageInput{
InstanceId: aws.String(instanceID),
Name: aws.String(fmt.Sprintf("backup-%s-%d", instanceID, time.Now().Unix())),
Description: aws.String("Migration backup"),
NoReboot: aws.Bool(true),
}
output, err := svc.CreateImage(context.TODO(), input)
if err != nil {
return "", err
}
return aws.ToString(output.ImageId), nil
}
// waitForImageReady polls every second to see if the image is ready
func waitForImageReady(svc *ec2.Client, imageID string, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
for {
select {
case <-ctx.Done():
return errors.New("timed out waiting for image to be ready")
case <-time.After(1 * time.Second):
input := &ec2.DescribeImagesInput{
ImageIds: []string{imageID},
}
output, err := svc.DescribeImages(ctx, input)
if err != nil {
return err
}
if len(output.Images) > 0 && output.Images[0].State == types.ImageStateAvailable {
return nil
}
}
}
}
return instances, nil
// waitForInstanceReady waits for the newly launched instance to be running and ready
func waitForInstanceReady(svc *ec2.Client, newInstanceID string, timeout time.Duration) error {
// Wait for the instance to be running
waitInput := &ec2.DescribeInstancesInput{
InstanceIds: []string{newInstanceID},
}
waiter := ec2.NewInstanceRunningWaiter(svc)
err := waiter.Wait(context.TODO(), waitInput, timeout)
if err != nil {
return err
}
return nil
}
// launchInstance launches a instance IN RANDOM AVAILABILITY ZONE within the same region, based on an oldInstance and AMI (duplicating the instance)
func launchInstance(svc *ec2.Client, oldInstance *types.Instance, imageID string, region string) (string, error) {
securityGroupIds := make([]string, len(oldInstance.SecurityGroups))
for i, sg := range oldInstance.SecurityGroups {
securityGroupIds[i] = aws.ToString(sg.GroupId)
}
availabilityZone, err := getRandomDifferentAvailabilityZone(svc, oldInstance, region)
if err != nil {
return "", err
}
var nameTag string
for _, tag := range oldInstance.Tags {
if aws.ToString(tag.Key) == "Name" {
nameTag = aws.ToString(tag.Value)
break
}
}
input := &ec2.RunInstancesInput{
ImageId: aws.String(imageID),
InstanceType: oldInstance.InstanceType,
MinCount: aws.Int32(1),
MaxCount: aws.Int32(1),
KeyName: oldInstance.KeyName,
SecurityGroupIds: securityGroupIds,
Placement: &types.Placement{
AvailabilityZone: aws.String(availabilityZone),
},
TagSpecifications: []types.TagSpecification{
{
ResourceType: types.ResourceTypeInstance,
Tags: []types.Tag{
{
Key: aws.String("Name"),
Value: aws.String(nameTag),
},
},
},
},
}
output, err := svc.RunInstances(context.TODO(), input)
if err != nil {
return "", err
}
// TODO: save/index config for the new instance
return aws.ToString(output.Instances[0].InstanceId), nil
}
// getRandomDifferentAvailabilityZone fetches all AZ from the same region as the instance and returns a random AZ that is not equal to the one used by the instance
func getRandomDifferentAvailabilityZone(svc *ec2.Client, instance *types.Instance, region string) (string, error) {
// Seed the random generator
rand.Seed(time.Now().UnixNano())
// Get the current availability zone of the instance
currentAZ := aws.ToString(instance.Placement.AvailabilityZone)
// Describe availability zones in the region
input := &ec2.DescribeAvailabilityZonesInput{
Filters: []types.Filter{
{
Name: aws.String("region-name"),
Values: []string{region},
},
},
}
output, err := svc.DescribeAvailabilityZones(context.TODO(), input)
if err != nil {
return "", err
}
// Filter out the current availability zone
availableAZs := []string{}
for _, az := range output.AvailabilityZones {
if aws.ToString(az.ZoneName) != currentAZ {
availableAZs = append(availableAZs, aws.ToString(az.ZoneName))
}
}
// If no other availability zones are available, return an error
if len(availableAZs) == 0 {
return "", errors.New("no other availability zones available")
}
// Select a random availability zone from the remaining ones
randomIndex := rand.Intn(len(availableAZs))
randomAZ := availableAZs[randomIndex]
return randomAZ, nil
}
// terminateInstance kills an instance by id
func terminateInstance(svc *ec2.Client, instanceID string) error {
input := &ec2.TerminateInstancesInput{
InstanceIds: []string{instanceID},
}
_, err := svc.TerminateInstances(context.TODO(), input)
return err
}
// describeImage gets info about an image from string
func describeImage(svc *ec2.Client, imageID string) (*types.Image, error) {
input := &ec2.DescribeImagesInput{
ImageIds: []string{imageID},
}
output, err := svc.DescribeImages(context.TODO(), input)
if err != nil {
return nil, err
}
if len(output.Images) == 0 {
return nil, errors.New("image not found")
}
return &output.Images[0], nil
}
// deregisterImage deletes the AMI passed as string
func deregisterImage(svc *ec2.Client, imageID string) error {
input := &ec2.DeregisterImageInput{
ImageId: aws.String(imageID),
}
_, err := svc.DeregisterImage(context.TODO(), input)
return err
}
// deleteSnapshot deletes the snapshot passed as string
func deleteSnapshot(svc *ec2.Client, snapshotID string) error {
input := &ec2.DeleteSnapshotInput{
SnapshotId: aws.String(snapshotID),
}
_, err := svc.DeleteSnapshot(context.TODO(), input)
return err
}
// getInstanceDetailsFromString does what the name says
func getInstanceDetailsFromString(svc *ec2.Client, instanceID string) (*types.Instance, error) {
input := &ec2.DescribeInstancesInput{
InstanceIds: []string{instanceID},
}
output, err := svc.DescribeInstances(context.TODO(), input)
if err != nil {
return nil, err
}
return &output.Reservations[0].Instances[0], nil
}

View File

@ -57,11 +57,7 @@ func (p Proxy) execute(c command) (string, error) {
if err != nil {
return "", errors.New(fmt.Sprintf("error making http request: %s\n", err))
}
fmt.Println(res)
body, err := io.ReadAll(res.Body)
fmt.Println(string(body))
body, err := ioutil.ReadAll(res.Body)
if err != nil {
return "", errors.New(fmt.Sprintf("error reading response: %s\n", err))
}

View File

@ -17,13 +17,16 @@ type Config struct {
}
type mtdconf struct {
Services []Service `yaml:"services"`
Services map[CustomUUID]Service `yaml:"services"`
ManagementPort uint16 `yaml:"management_port"`
}
// Service contains all necessary information about a service to identify it in the cloud as well as configuring a proxy for it
type Service struct {
ID CustomUUID `yaml:"id"`
CloudID string `yaml:"cloud_id"`
AdminEnabled bool `yaml:"admin_enabled"`
Active bool `yaml:"active"`
EntryIP netip.Addr `yaml:"entry_ip"`
EntryPort uint16 `yaml:"entry_port"`
ServiceIP netip.Addr `yaml:"service_ip"`