Merge branch 'basic-functionality' into test-pc
This commit is contained in:
commit
17286ca4cc
@ -1,5 +1,5 @@
|
||||
mtd:
|
||||
services: []
|
||||
services: {}
|
||||
aws:
|
||||
regions: []
|
||||
credentials_path: ./mtdaws/.credentials
|
||||
|
85
main.go
85
main.go
@ -3,6 +3,7 @@ package main
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/thefeli73/polemos/mtdaws"
|
||||
@ -17,45 +18,99 @@ func main() {
|
||||
|
||||
ConfigPath = "config.yaml"
|
||||
|
||||
config := state.LoadConf(ConfigPath)
|
||||
// Initialize the config.Services map
|
||||
var config state.Config
|
||||
config.MTD.Services = make(map[state.CustomUUID]state.Service)
|
||||
|
||||
config = state.LoadConf(ConfigPath)
|
||||
state.SaveConf(ConfigPath, config)
|
||||
|
||||
config = indexInstances(config)
|
||||
|
||||
config = indexAllInstances(config)
|
||||
state.SaveConf(ConfigPath, config)
|
||||
|
||||
// START DOING MTD
|
||||
mtdLoop(config)
|
||||
}
|
||||
|
||||
func indexInstances(config state.Config) state.Config {
|
||||
func mtdLoop(config state.Config) {
|
||||
for true {
|
||||
//TODO: figure out migration (MTD)
|
||||
config = movingTargetDefense(config)
|
||||
state.SaveConf(ConfigPath, config)
|
||||
|
||||
fmt.Println("Sleeping for 5 seconds")
|
||||
time.Sleep(5*time.Second)
|
||||
|
||||
//TODO: proxy commands
|
||||
}
|
||||
}
|
||||
|
||||
func movingTargetDefense(config state.Config) state.Config{
|
||||
|
||||
mtdaws.AWSMoveInstance(config)
|
||||
return config
|
||||
}
|
||||
|
||||
func indexAllInstances(config state.Config) state.Config {
|
||||
fmt.Println("Indexing instances")
|
||||
t := time.Now()
|
||||
|
||||
for _, service := range config.MTD.Services {
|
||||
service.Active = false
|
||||
}
|
||||
|
||||
//index AWS instances
|
||||
awsNewInstanceCounter := 0
|
||||
awsInactiveInstanceCounter := len(config.MTD.Services)
|
||||
awsInstanceCounter := 0
|
||||
awsInstances := mtdaws.GetInstances(config)
|
||||
for _, instance := range awsInstances {
|
||||
cloudID := mtdaws.GetCloudID(instance)
|
||||
ip, err := netip.ParseAddr(instance.PublicIP)
|
||||
if err != nil {
|
||||
fmt.Println("Error converting ip:", err)
|
||||
fmt.Println("Error converting ip:\t", err)
|
||||
continue
|
||||
}
|
||||
newService, found := indexInstance(config, cloudID, ip)
|
||||
var found bool
|
||||
config, found = indexInstance(config, cloudID, ip)
|
||||
if !found {
|
||||
config.MTD.Services = append(config.MTD.Services, newService)
|
||||
state.SaveConf(ConfigPath, config)
|
||||
awsNewInstanceCounter++
|
||||
} else {
|
||||
awsInactiveInstanceCounter--
|
||||
}
|
||||
awsInstanceCounter++
|
||||
}
|
||||
// TODO: Purge instances in config that are not found in the cloud
|
||||
fmt.Printf("Found %d active AWS instances (%d newly added, %d inactive) (took %s)\n",
|
||||
awsInstanceCounter, awsNewInstanceCounter, awsInactiveInstanceCounter, time.Since(t).Round(100*time.Millisecond).String())
|
||||
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func indexInstance(config state.Config, cloudID string, serviceIP netip.Addr) (state.Service, bool) {
|
||||
func indexInstance(config state.Config, cloudID string, serviceIP netip.Addr) (state.Config, bool) {
|
||||
found := false
|
||||
for _, service := range config.MTD.Services {
|
||||
var foundUUID state.CustomUUID
|
||||
for u, service := range config.MTD.Services {
|
||||
if service.CloudID == cloudID {
|
||||
found = true
|
||||
foundUUID = u
|
||||
break;
|
||||
}
|
||||
}
|
||||
u := uuid.New()
|
||||
newService := state.Service{
|
||||
ID: state.CustomUUID(u),
|
||||
CloudID: cloudID,
|
||||
ServiceIP: serviceIP}
|
||||
return newService, found
|
||||
|
||||
if !found {
|
||||
fmt.Println("New instance found:\t", cloudID)
|
||||
u := uuid.New()
|
||||
config.MTD.Services[state.CustomUUID(u)] = state.Service{CloudID: cloudID, ServiceIP: serviceIP, Active: true, AdminEnabled: true}
|
||||
state.SaveConf(ConfigPath, config)
|
||||
|
||||
} else {
|
||||
s := config.MTD.Services[foundUUID]
|
||||
s.Active = true
|
||||
config.MTD.Services[foundUUID] = s
|
||||
state.SaveConf(ConfigPath, config)
|
||||
}
|
||||
return config, found
|
||||
}
|
||||
|
142
mtdaws/mtd.go
Normal file
142
mtdaws/mtd.go
Normal file
@ -0,0 +1,142 @@
|
||||
package mtdaws
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/service/ec2"
|
||||
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
|
||||
"github.com/google/uuid"
|
||||
"github.com/thefeli73/polemos/state"
|
||||
)
|
||||
|
||||
// AWSUpdateService updates a specified service config to match a newly moved instance
|
||||
func AWSUpdateService(config state.Config, region string, service state.CustomUUID, newInstanceID string) (state.Config) {
|
||||
awsConfig := NewConfig(region, config.AWS.CredentialsPath)
|
||||
svc := ec2.NewFromConfig(awsConfig)
|
||||
instance, err := getInstanceDetailsFromString(svc, newInstanceID)
|
||||
if err != nil {
|
||||
fmt.Println("Error getting instance details:\t", err)
|
||||
return config
|
||||
}
|
||||
|
||||
var publicAddr string
|
||||
if instance.PublicIpAddress != nil {
|
||||
publicAddr = aws.ToString(instance.PublicIpAddress)
|
||||
}
|
||||
formattedinstance := AwsInstance{
|
||||
InstanceID: aws.ToString(instance.InstanceId),
|
||||
Region: region,
|
||||
PublicIP: publicAddr,
|
||||
PrivateIP: aws.ToString(instance.PrivateIpAddress),
|
||||
}
|
||||
cloudid := GetCloudID(formattedinstance)
|
||||
serviceip := netip.MustParseAddr(publicAddr)
|
||||
config.MTD.Services[service] = state.Service{CloudID: cloudid, ServiceIP: serviceip}
|
||||
return config
|
||||
}
|
||||
|
||||
|
||||
// isInstanceRunning returns if an instance is running (true=running)
|
||||
func isInstanceRunning(instance *types.Instance) bool {
|
||||
return instance.State.Name == types.InstanceStateNameRunning
|
||||
}
|
||||
|
||||
// AWSMoveInstance moves a specified instance to a new availability region
|
||||
func AWSMoveInstance(config state.Config) (state.Config) {
|
||||
|
||||
// pseudorandom instance from all services for testing
|
||||
var serviceUUID state.CustomUUID
|
||||
var instance state.Service
|
||||
for key, service := range config.MTD.Services {
|
||||
serviceUUID = key
|
||||
instance = service
|
||||
if !instance.AdminEnabled {continue}
|
||||
if !instance.Active {continue}
|
||||
break
|
||||
}
|
||||
|
||||
fmt.Println("MTD move service:\t", uuid.UUID.String(uuid.UUID(serviceUUID)))
|
||||
|
||||
region, instanceID := DecodeCloudID(instance.CloudID)
|
||||
awsConfig := NewConfig(region, config.AWS.CredentialsPath)
|
||||
svc := ec2.NewFromConfig(awsConfig)
|
||||
|
||||
realInstance, err := getInstanceDetailsFromString(svc, instanceID)
|
||||
if err != nil {
|
||||
fmt.Println("Error getting instance details:\t", err)
|
||||
return config
|
||||
}
|
||||
if !isInstanceRunning(realInstance) {
|
||||
fmt.Println("Error, Instance is not running!")
|
||||
return config
|
||||
}
|
||||
|
||||
//Create image
|
||||
t := time.Now()
|
||||
imageName, err := createImage(svc, instanceID)
|
||||
if err != nil {
|
||||
fmt.Println("Error creating image:\t", err)
|
||||
return config
|
||||
}
|
||||
fmt.Printf("Created image:\t\t%s (took %s)\n", imageName, time.Since(t).Round(100*time.Millisecond).String())
|
||||
|
||||
// Wait for image
|
||||
t = time.Now()
|
||||
err = waitForImageReady(svc, imageName, 5*time.Minute)
|
||||
if err != nil {
|
||||
fmt.Println("Error waiting for image to be ready:\t", err)
|
||||
return config
|
||||
}
|
||||
fmt.Printf("Image is ready:\t\t%s (took %s)\n", imageName, time.Since(t).Round(100*time.Millisecond).String())
|
||||
|
||||
// Launch new instance
|
||||
t = time.Now()
|
||||
newInstanceID, err := launchInstance(svc, realInstance, imageName, region)
|
||||
if err != nil {
|
||||
fmt.Println("Error launching instance:\t", err)
|
||||
return config
|
||||
}
|
||||
fmt.Printf("Launched new instance:\t%s (took %s)\n", newInstanceID, time.Since(t).Round(100*time.Millisecond).String())
|
||||
|
||||
// Terminate old instance
|
||||
t = time.Now()
|
||||
err = terminateInstance(svc, instanceID)
|
||||
if err != nil {
|
||||
fmt.Println("Error terminating instance:\t", err)
|
||||
return config
|
||||
}
|
||||
fmt.Printf("Killed old instance:\t%s (took %s)\n", instanceID, time.Since(t).Round(100*time.Millisecond).String())
|
||||
|
||||
// Deregister old image
|
||||
t = time.Now()
|
||||
image, err := describeImage(svc, imageName)
|
||||
if err != nil {
|
||||
fmt.Println("Error describing image:\t", err)
|
||||
return config
|
||||
}
|
||||
err = deregisterImage(svc, imageName)
|
||||
if err != nil {
|
||||
fmt.Println("Error deregistering image:\t", err)
|
||||
return config
|
||||
}
|
||||
fmt.Printf("Deregistered image:\t%s (took %s)\n", imageName, time.Since(t).Round(100*time.Millisecond).String())
|
||||
|
||||
// Delete old snapshot
|
||||
t = time.Now()
|
||||
if len(image.BlockDeviceMappings) > 0 {
|
||||
snapshotID := aws.ToString(image.BlockDeviceMappings[0].Ebs.SnapshotId)
|
||||
err = deleteSnapshot(svc, snapshotID)
|
||||
if err != nil {
|
||||
fmt.Println("Error deleting snapshot:\t", err)
|
||||
return config
|
||||
}
|
||||
fmt.Printf("Deleted snapshot:\t%s (took %s)\n", snapshotID, time.Since(t).Round(100*time.Millisecond).String())
|
||||
}
|
||||
|
||||
AWSUpdateService(config, region, serviceUUID, newInstanceID)
|
||||
|
||||
return config
|
||||
}
|
217
mtdaws/utils.go
217
mtdaws/utils.go
@ -2,8 +2,12 @@ package mtdaws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
@ -36,6 +40,17 @@ func GetCloudID(instance AwsInstance) string {
|
||||
return "aws_" + instance.Region + "_" + instance.InstanceID
|
||||
}
|
||||
|
||||
// DecodeCloudID returns information to locate instance in aws
|
||||
func DecodeCloudID(cloudID string) (string, string) {
|
||||
split := strings.Split(cloudID, "_")
|
||||
if len(split) != 3 {
|
||||
panic(cloudID + " does not decode as AWS CloudID")
|
||||
}
|
||||
region := split[1]
|
||||
instanceID := split[2]
|
||||
return region, instanceID
|
||||
}
|
||||
|
||||
// GetInstances scans all configured regions for instances and add them to services
|
||||
func GetInstances(config state.Config) []AwsInstance {
|
||||
awsInstances := []AwsInstance{}
|
||||
@ -46,8 +61,6 @@ func GetInstances(config state.Config) []AwsInstance {
|
||||
fmt.Println("Error listing instances:", err)
|
||||
continue
|
||||
}
|
||||
|
||||
//fmt.Println("Listing instances in region:", region)
|
||||
for _, instance := range instances {
|
||||
var publicAddr string
|
||||
if instance.PublicIpAddress != nil {
|
||||
@ -63,25 +76,12 @@ func GetInstances(config state.Config) []AwsInstance {
|
||||
return awsInstances
|
||||
}
|
||||
|
||||
// PrintInstanceInfo prints info about a specific instance in a region
|
||||
func PrintInstanceInfo(instance *types.Instance) {
|
||||
fmt.Println("\tInstance ID:", aws.ToString(instance.InstanceId))
|
||||
fmt.Println("\t\tInstance Type:", string(instance.InstanceType))
|
||||
fmt.Println("\t\tAMI ID:", aws.ToString(instance.ImageId))
|
||||
fmt.Println("\t\tState:", string(instance.State.Name))
|
||||
fmt.Println("\t\tAvailability Zone:", aws.ToString(instance.Placement.AvailabilityZone))
|
||||
if instance.PublicIpAddress != nil {
|
||||
fmt.Println("\t\tPublic IP Address:", aws.ToString(instance.PublicIpAddress))
|
||||
}
|
||||
fmt.Println("\t\tPrivate IP Address:", aws.ToString(instance.PrivateIpAddress))
|
||||
}
|
||||
|
||||
// Instances returns all instances for a config i.e. a region
|
||||
func Instances(config aws.Config) ([]*types.Instance, error) {
|
||||
func Instances(config aws.Config) ([]types.Instance, error) {
|
||||
svc := ec2.NewFromConfig(config)
|
||||
|
||||
input := &ec2.DescribeInstancesInput{}
|
||||
var instances []*types.Instance
|
||||
var instances []types.Instance
|
||||
|
||||
paginator := ec2.NewDescribeInstancesPaginator(svc, input)
|
||||
|
||||
@ -93,10 +93,189 @@ func Instances(config aws.Config) ([]*types.Instance, error) {
|
||||
|
||||
for _, reservation := range page.Reservations {
|
||||
for _, instance := range reservation.Instances {
|
||||
instances = append(instances, &instance)
|
||||
instances = append(instances, instance)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return instances, nil
|
||||
}
|
||||
|
||||
// createImage will create an AMI (amazon machine image) of a given instance
|
||||
func createImage(svc *ec2.Client, instanceID string) (string, error) {
|
||||
input := &ec2.CreateImageInput{
|
||||
InstanceId: aws.String(instanceID),
|
||||
Name: aws.String(fmt.Sprintf("backup-%s-%d", instanceID, time.Now().Unix())),
|
||||
Description: aws.String("Migration backup"),
|
||||
NoReboot: aws.Bool(true),
|
||||
}
|
||||
|
||||
output, err := svc.CreateImage(context.TODO(), input)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return aws.ToString(output.ImageId), nil
|
||||
}
|
||||
|
||||
// waitForImageReady polls every second to see if the image is ready
|
||||
func waitForImageReady(svc *ec2.Client, imageID string, timeout time.Duration) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return errors.New("timed out waiting for image to be ready")
|
||||
case <-time.After(1 * time.Second):
|
||||
input := &ec2.DescribeImagesInput{
|
||||
ImageIds: []string{imageID},
|
||||
}
|
||||
output, err := svc.DescribeImages(ctx, input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(output.Images) > 0 && output.Images[0].State == types.ImageStateAvailable {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// launchInstance launches a instance IN RANDOM AVAILABILITY ZONE within the same region, based on an oldInstance and AMI (duplicating the instance)
|
||||
func launchInstance(svc *ec2.Client, oldInstance *types.Instance, imageID string, region string) (string, error) {
|
||||
securityGroupIds := make([]string, len(oldInstance.SecurityGroups))
|
||||
for i, sg := range oldInstance.SecurityGroups {
|
||||
securityGroupIds[i] = aws.ToString(sg.GroupId)
|
||||
}
|
||||
// TODO: select random zone that is not the current one.
|
||||
availabilityZone, err := getRandomDifferentAvailabilityZone(svc, oldInstance, region)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
input := &ec2.RunInstancesInput{
|
||||
ImageId: aws.String(imageID),
|
||||
InstanceType: oldInstance.InstanceType,
|
||||
MinCount: aws.Int32(1),
|
||||
MaxCount: aws.Int32(1),
|
||||
KeyName: oldInstance.KeyName,
|
||||
SecurityGroupIds: securityGroupIds,
|
||||
Placement: &types.Placement{
|
||||
AvailabilityZone: aws.String(availabilityZone),
|
||||
},
|
||||
}
|
||||
|
||||
output, err := svc.RunInstances(context.TODO(), input)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// TODO: save/index config for the new instance
|
||||
|
||||
return aws.ToString(output.Instances[0].InstanceId), nil
|
||||
}
|
||||
|
||||
// getRandomDifferentAvailabilityZone fetches all AZ from the same region as the instance and returns a random AZ that is not equal to the one used by the instance
|
||||
func getRandomDifferentAvailabilityZone(svc *ec2.Client, instance *types.Instance, region string) (string, error) {
|
||||
// Seed the random generator
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
|
||||
// Get the current availability zone of the instance
|
||||
currentAZ := aws.ToString(instance.Placement.AvailabilityZone)
|
||||
|
||||
// Describe availability zones in the region
|
||||
input := &ec2.DescribeAvailabilityZonesInput{
|
||||
Filters: []types.Filter{
|
||||
{
|
||||
Name: aws.String("region-name"),
|
||||
Values: []string{region},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
output, err := svc.DescribeAvailabilityZones(context.TODO(), input)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Filter out the current availability zone
|
||||
availableAZs := []string{}
|
||||
for _, az := range output.AvailabilityZones {
|
||||
if aws.ToString(az.ZoneName) != currentAZ {
|
||||
availableAZs = append(availableAZs, aws.ToString(az.ZoneName))
|
||||
}
|
||||
}
|
||||
|
||||
// If no other availability zones are available, return an error
|
||||
if len(availableAZs) == 0 {
|
||||
return "", errors.New("no other availability zones available")
|
||||
}
|
||||
|
||||
// Select a random availability zone from the remaining ones
|
||||
randomIndex := rand.Intn(len(availableAZs))
|
||||
randomAZ := availableAZs[randomIndex]
|
||||
return randomAZ, nil
|
||||
}
|
||||
|
||||
|
||||
// terminateInstance kills an instance by id
|
||||
func terminateInstance(svc *ec2.Client, instanceID string) error {
|
||||
input := &ec2.TerminateInstancesInput{
|
||||
InstanceIds: []string{instanceID},
|
||||
}
|
||||
_, err := svc.TerminateInstances(context.TODO(), input)
|
||||
return err
|
||||
}
|
||||
|
||||
// describeImage gets info about an image from string
|
||||
func describeImage(svc *ec2.Client, imageID string) (*types.Image, error) {
|
||||
input := &ec2.DescribeImagesInput{
|
||||
ImageIds: []string{imageID},
|
||||
}
|
||||
|
||||
output, err := svc.DescribeImages(context.TODO(), input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(output.Images) == 0 {
|
||||
return nil, errors.New("image not found")
|
||||
}
|
||||
|
||||
return &output.Images[0], nil
|
||||
}
|
||||
|
||||
// deregisterImage deletes the AMI passed as string
|
||||
func deregisterImage(svc *ec2.Client, imageID string) error {
|
||||
input := &ec2.DeregisterImageInput{
|
||||
ImageId: aws.String(imageID),
|
||||
}
|
||||
|
||||
_, err := svc.DeregisterImage(context.TODO(), input)
|
||||
return err
|
||||
}
|
||||
|
||||
// deleteSnapshot deletes the snapshot passed as string
|
||||
func deleteSnapshot(svc *ec2.Client, snapshotID string) error {
|
||||
input := &ec2.DeleteSnapshotInput{
|
||||
SnapshotId: aws.String(snapshotID),
|
||||
}
|
||||
|
||||
_, err := svc.DeleteSnapshot(context.TODO(), input)
|
||||
return err
|
||||
}
|
||||
|
||||
// getInstanceDetailsFromString does what the name says
|
||||
func getInstanceDetailsFromString(svc *ec2.Client, instanceID string) (*types.Instance, error) {
|
||||
input := &ec2.DescribeInstancesInput{
|
||||
InstanceIds: []string{instanceID},
|
||||
}
|
||||
|
||||
output, err := svc.DescribeInstances(context.TODO(), input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &output.Reservations[0].Instances[0], nil
|
||||
}
|
||||
|
@ -17,13 +17,15 @@ type Config struct {
|
||||
}
|
||||
|
||||
type mtdconf struct {
|
||||
Services []Service `yaml:"services"`
|
||||
Services map[CustomUUID]Service `yaml:"services"`
|
||||
|
||||
}
|
||||
|
||||
// Service contains all necessary information about a service to identify it in the cloud as well as configuring a proxy for it
|
||||
type Service struct {
|
||||
ID CustomUUID `yaml:"id"`
|
||||
CloudID string `yaml:"cloud_id"`
|
||||
AdminEnabled bool `yaml:"admin_enabled"`
|
||||
Active bool `yaml:"active"`
|
||||
EntryIP netip.Addr `yaml:"entry_ip"`
|
||||
EntryPort uint16 `yaml:"entry_port"`
|
||||
ServiceIP netip.Addr `yaml:"service_ip"`
|
||||
|
Loading…
x
Reference in New Issue
Block a user