diff --git a/config.default.yaml b/config.default.yaml index 4420f39..77240ff 100644 --- a/config.default.yaml +++ b/config.default.yaml @@ -1,5 +1,5 @@ mtd: - services: [] + services: {} aws: regions: [] credentials_path: ./mtdaws/.credentials diff --git a/main.go b/main.go index 3db5399..0387f45 100644 --- a/main.go +++ b/main.go @@ -3,6 +3,7 @@ package main import ( "fmt" "net/netip" + "time" "github.com/google/uuid" "github.com/thefeli73/polemos/mtdaws" @@ -17,45 +18,99 @@ func main() { ConfigPath = "config.yaml" - config := state.LoadConf(ConfigPath) + // Initialize the config.Services map + var config state.Config + config.MTD.Services = make(map[state.CustomUUID]state.Service) + + config = state.LoadConf(ConfigPath) state.SaveConf(ConfigPath, config) - config = indexInstances(config) + config = indexAllInstances(config) + state.SaveConf(ConfigPath, config) + + // START DOING MTD + mtdLoop(config) } -func indexInstances(config state.Config) state.Config { +func mtdLoop(config state.Config) { + for true { + //TODO: figure out migration (MTD) + config = movingTargetDefense(config) + state.SaveConf(ConfigPath, config) + + fmt.Println("Sleeping for 5 seconds") + time.Sleep(5*time.Second) + + //TODO: proxy commands + } +} + +func movingTargetDefense(config state.Config) state.Config{ + + mtdaws.AWSMoveInstance(config) + return config +} + +func indexAllInstances(config state.Config) state.Config { fmt.Println("Indexing instances") + t := time.Now() + + for _, service := range config.MTD.Services { + service.Active = false + } //index AWS instances + awsNewInstanceCounter := 0 + awsInactiveInstanceCounter := len(config.MTD.Services) + awsInstanceCounter := 0 awsInstances := mtdaws.GetInstances(config) for _, instance := range awsInstances { cloudID := mtdaws.GetCloudID(instance) ip, err := netip.ParseAddr(instance.PublicIP) if err != nil { - fmt.Println("Error converting ip:", err) + fmt.Println("Error converting ip:\t", err) continue } - newService, found := indexInstance(config, cloudID, ip) + var found bool + config, found = indexInstance(config, cloudID, ip) if !found { - config.MTD.Services = append(config.MTD.Services, newService) - state.SaveConf(ConfigPath, config) + awsNewInstanceCounter++ + } else { + awsInactiveInstanceCounter-- } + awsInstanceCounter++ } + // TODO: Purge instances in config that are not found in the cloud + fmt.Printf("Found %d active AWS instances (%d newly added, %d inactive) (took %s)\n", + awsInstanceCounter, awsNewInstanceCounter, awsInactiveInstanceCounter, time.Since(t).Round(100*time.Millisecond).String()) + + return config } -func indexInstance(config state.Config, cloudID string, serviceIP netip.Addr) (state.Service, bool) { +func indexInstance(config state.Config, cloudID string, serviceIP netip.Addr) (state.Config, bool) { found := false - for _, service := range config.MTD.Services { + var foundUUID state.CustomUUID + for u, service := range config.MTD.Services { if service.CloudID == cloudID { found = true + foundUUID = u + break; } } - u := uuid.New() - newService := state.Service{ - ID: state.CustomUUID(u), - CloudID: cloudID, - ServiceIP: serviceIP} - return newService, found + + if !found { + fmt.Println("New instance found:\t", cloudID) + u := uuid.New() + config.MTD.Services[state.CustomUUID(u)] = state.Service{CloudID: cloudID, ServiceIP: serviceIP, Active: true, AdminEnabled: true} + state.SaveConf(ConfigPath, config) + + } else { + s := config.MTD.Services[foundUUID] + s.Active = true + config.MTD.Services[foundUUID] = s + state.SaveConf(ConfigPath, config) + } + return config, found } diff --git a/mtdaws/mtd.go b/mtdaws/mtd.go new file mode 100644 index 0000000..ab69354 --- /dev/null +++ b/mtdaws/mtd.go @@ -0,0 +1,142 @@ +package mtdaws + +import ( + "fmt" + "net/netip" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/google/uuid" + "github.com/thefeli73/polemos/state" +) + +// AWSUpdateService updates a specified service config to match a newly moved instance +func AWSUpdateService(config state.Config, region string, service state.CustomUUID, newInstanceID string) (state.Config) { + awsConfig := NewConfig(region, config.AWS.CredentialsPath) + svc := ec2.NewFromConfig(awsConfig) + instance, err := getInstanceDetailsFromString(svc, newInstanceID) + if err != nil { + fmt.Println("Error getting instance details:\t", err) + return config + } + + var publicAddr string + if instance.PublicIpAddress != nil { + publicAddr = aws.ToString(instance.PublicIpAddress) + } + formattedinstance := AwsInstance{ + InstanceID: aws.ToString(instance.InstanceId), + Region: region, + PublicIP: publicAddr, + PrivateIP: aws.ToString(instance.PrivateIpAddress), + } + cloudid := GetCloudID(formattedinstance) + serviceip := netip.MustParseAddr(publicAddr) + config.MTD.Services[service] = state.Service{CloudID: cloudid, ServiceIP: serviceip} + return config +} + + +// isInstanceRunning returns if an instance is running (true=running) +func isInstanceRunning(instance *types.Instance) bool { + return instance.State.Name == types.InstanceStateNameRunning +} + +// AWSMoveInstance moves a specified instance to a new availability region +func AWSMoveInstance(config state.Config) (state.Config) { + + // pseudorandom instance from all services for testing + var serviceUUID state.CustomUUID + var instance state.Service + for key, service := range config.MTD.Services { + serviceUUID = key + instance = service + if !instance.AdminEnabled {continue} + if !instance.Active {continue} + break + } + + fmt.Println("MTD move service:\t", uuid.UUID.String(uuid.UUID(serviceUUID))) + + region, instanceID := DecodeCloudID(instance.CloudID) + awsConfig := NewConfig(region, config.AWS.CredentialsPath) + svc := ec2.NewFromConfig(awsConfig) + + realInstance, err := getInstanceDetailsFromString(svc, instanceID) + if err != nil { + fmt.Println("Error getting instance details:\t", err) + return config + } + if !isInstanceRunning(realInstance) { + fmt.Println("Error, Instance is not running!") + return config + } + + //Create image + t := time.Now() + imageName, err := createImage(svc, instanceID) + if err != nil { + fmt.Println("Error creating image:\t", err) + return config + } + fmt.Printf("Created image:\t\t%s (took %s)\n", imageName, time.Since(t).Round(100*time.Millisecond).String()) + + // Wait for image + t = time.Now() + err = waitForImageReady(svc, imageName, 5*time.Minute) + if err != nil { + fmt.Println("Error waiting for image to be ready:\t", err) + return config + } + fmt.Printf("Image is ready:\t\t%s (took %s)\n", imageName, time.Since(t).Round(100*time.Millisecond).String()) + + // Launch new instance + t = time.Now() + newInstanceID, err := launchInstance(svc, realInstance, imageName, region) + if err != nil { + fmt.Println("Error launching instance:\t", err) + return config + } + fmt.Printf("Launched new instance:\t%s (took %s)\n", newInstanceID, time.Since(t).Round(100*time.Millisecond).String()) + + // Terminate old instance + t = time.Now() + err = terminateInstance(svc, instanceID) + if err != nil { + fmt.Println("Error terminating instance:\t", err) + return config + } + fmt.Printf("Killed old instance:\t%s (took %s)\n", instanceID, time.Since(t).Round(100*time.Millisecond).String()) + + // Deregister old image + t = time.Now() + image, err := describeImage(svc, imageName) + if err != nil { + fmt.Println("Error describing image:\t", err) + return config + } + err = deregisterImage(svc, imageName) + if err != nil { + fmt.Println("Error deregistering image:\t", err) + return config + } + fmt.Printf("Deregistered image:\t%s (took %s)\n", imageName, time.Since(t).Round(100*time.Millisecond).String()) + + // Delete old snapshot + t = time.Now() + if len(image.BlockDeviceMappings) > 0 { + snapshotID := aws.ToString(image.BlockDeviceMappings[0].Ebs.SnapshotId) + err = deleteSnapshot(svc, snapshotID) + if err != nil { + fmt.Println("Error deleting snapshot:\t", err) + return config + } + fmt.Printf("Deleted snapshot:\t%s (took %s)\n", snapshotID, time.Since(t).Round(100*time.Millisecond).String()) + } + + AWSUpdateService(config, region, serviceUUID, newInstanceID) + + return config +} diff --git a/mtdaws/utils.go b/mtdaws/utils.go index b59805b..f81850b 100644 --- a/mtdaws/utils.go +++ b/mtdaws/utils.go @@ -2,8 +2,12 @@ package mtdaws import ( "context" + "errors" "fmt" + "math/rand" "os" + "strings" + "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" @@ -36,6 +40,17 @@ func GetCloudID(instance AwsInstance) string { return "aws_" + instance.Region + "_" + instance.InstanceID } +// DecodeCloudID returns information to locate instance in aws +func DecodeCloudID(cloudID string) (string, string) { + split := strings.Split(cloudID, "_") + if len(split) != 3 { + panic(cloudID + " does not decode as AWS CloudID") + } + region := split[1] + instanceID := split[2] + return region, instanceID +} + // GetInstances scans all configured regions for instances and add them to services func GetInstances(config state.Config) []AwsInstance { awsInstances := []AwsInstance{} @@ -46,8 +61,6 @@ func GetInstances(config state.Config) []AwsInstance { fmt.Println("Error listing instances:", err) continue } - - //fmt.Println("Listing instances in region:", region) for _, instance := range instances { var publicAddr string if instance.PublicIpAddress != nil { @@ -63,25 +76,12 @@ func GetInstances(config state.Config) []AwsInstance { return awsInstances } -// PrintInstanceInfo prints info about a specific instance in a region -func PrintInstanceInfo(instance *types.Instance) { - fmt.Println("\tInstance ID:", aws.ToString(instance.InstanceId)) - fmt.Println("\t\tInstance Type:", string(instance.InstanceType)) - fmt.Println("\t\tAMI ID:", aws.ToString(instance.ImageId)) - fmt.Println("\t\tState:", string(instance.State.Name)) - fmt.Println("\t\tAvailability Zone:", aws.ToString(instance.Placement.AvailabilityZone)) - if instance.PublicIpAddress != nil { - fmt.Println("\t\tPublic IP Address:", aws.ToString(instance.PublicIpAddress)) - } - fmt.Println("\t\tPrivate IP Address:", aws.ToString(instance.PrivateIpAddress)) -} - // Instances returns all instances for a config i.e. a region -func Instances(config aws.Config) ([]*types.Instance, error) { +func Instances(config aws.Config) ([]types.Instance, error) { svc := ec2.NewFromConfig(config) input := &ec2.DescribeInstancesInput{} - var instances []*types.Instance + var instances []types.Instance paginator := ec2.NewDescribeInstancesPaginator(svc, input) @@ -93,10 +93,189 @@ func Instances(config aws.Config) ([]*types.Instance, error) { for _, reservation := range page.Reservations { for _, instance := range reservation.Instances { - instances = append(instances, &instance) + instances = append(instances, instance) } } } - return instances, nil } + +// createImage will create an AMI (amazon machine image) of a given instance +func createImage(svc *ec2.Client, instanceID string) (string, error) { + input := &ec2.CreateImageInput{ + InstanceId: aws.String(instanceID), + Name: aws.String(fmt.Sprintf("backup-%s-%d", instanceID, time.Now().Unix())), + Description: aws.String("Migration backup"), + NoReboot: aws.Bool(true), + } + + output, err := svc.CreateImage(context.TODO(), input) + if err != nil { + return "", err + } + + return aws.ToString(output.ImageId), nil +} + +// waitForImageReady polls every second to see if the image is ready +func waitForImageReady(svc *ec2.Client, imageID string, timeout time.Duration) error { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + for { + select { + case <-ctx.Done(): + return errors.New("timed out waiting for image to be ready") + case <-time.After(1 * time.Second): + input := &ec2.DescribeImagesInput{ + ImageIds: []string{imageID}, + } + output, err := svc.DescribeImages(ctx, input) + if err != nil { + return err + } + + if len(output.Images) > 0 && output.Images[0].State == types.ImageStateAvailable { + return nil + } + } + } +} + +// launchInstance launches a instance IN RANDOM AVAILABILITY ZONE within the same region, based on an oldInstance and AMI (duplicating the instance) +func launchInstance(svc *ec2.Client, oldInstance *types.Instance, imageID string, region string) (string, error) { + securityGroupIds := make([]string, len(oldInstance.SecurityGroups)) + for i, sg := range oldInstance.SecurityGroups { + securityGroupIds[i] = aws.ToString(sg.GroupId) + } + // TODO: select random zone that is not the current one. + availabilityZone, err := getRandomDifferentAvailabilityZone(svc, oldInstance, region) + if err != nil { + return "", err + } + + input := &ec2.RunInstancesInput{ + ImageId: aws.String(imageID), + InstanceType: oldInstance.InstanceType, + MinCount: aws.Int32(1), + MaxCount: aws.Int32(1), + KeyName: oldInstance.KeyName, + SecurityGroupIds: securityGroupIds, + Placement: &types.Placement{ + AvailabilityZone: aws.String(availabilityZone), + }, + } + + output, err := svc.RunInstances(context.TODO(), input) + if err != nil { + return "", err + } + + // TODO: save/index config for the new instance + + return aws.ToString(output.Instances[0].InstanceId), nil +} + +// getRandomDifferentAvailabilityZone fetches all AZ from the same region as the instance and returns a random AZ that is not equal to the one used by the instance +func getRandomDifferentAvailabilityZone(svc *ec2.Client, instance *types.Instance, region string) (string, error) { + // Seed the random generator + rand.Seed(time.Now().UnixNano()) + + // Get the current availability zone of the instance + currentAZ := aws.ToString(instance.Placement.AvailabilityZone) + + // Describe availability zones in the region + input := &ec2.DescribeAvailabilityZonesInput{ + Filters: []types.Filter{ + { + Name: aws.String("region-name"), + Values: []string{region}, + }, + }, + } + + output, err := svc.DescribeAvailabilityZones(context.TODO(), input) + if err != nil { + return "", err + } + + // Filter out the current availability zone + availableAZs := []string{} + for _, az := range output.AvailabilityZones { + if aws.ToString(az.ZoneName) != currentAZ { + availableAZs = append(availableAZs, aws.ToString(az.ZoneName)) + } + } + + // If no other availability zones are available, return an error + if len(availableAZs) == 0 { + return "", errors.New("no other availability zones available") + } + + // Select a random availability zone from the remaining ones + randomIndex := rand.Intn(len(availableAZs)) + randomAZ := availableAZs[randomIndex] + return randomAZ, nil +} + + +// terminateInstance kills an instance by id +func terminateInstance(svc *ec2.Client, instanceID string) error { + input := &ec2.TerminateInstancesInput{ + InstanceIds: []string{instanceID}, + } + _, err := svc.TerminateInstances(context.TODO(), input) + return err +} + +// describeImage gets info about an image from string +func describeImage(svc *ec2.Client, imageID string) (*types.Image, error) { + input := &ec2.DescribeImagesInput{ + ImageIds: []string{imageID}, + } + + output, err := svc.DescribeImages(context.TODO(), input) + if err != nil { + return nil, err + } + + if len(output.Images) == 0 { + return nil, errors.New("image not found") + } + + return &output.Images[0], nil +} + +// deregisterImage deletes the AMI passed as string +func deregisterImage(svc *ec2.Client, imageID string) error { + input := &ec2.DeregisterImageInput{ + ImageId: aws.String(imageID), + } + + _, err := svc.DeregisterImage(context.TODO(), input) + return err +} + +// deleteSnapshot deletes the snapshot passed as string +func deleteSnapshot(svc *ec2.Client, snapshotID string) error { + input := &ec2.DeleteSnapshotInput{ + SnapshotId: aws.String(snapshotID), + } + + _, err := svc.DeleteSnapshot(context.TODO(), input) + return err +} + +// getInstanceDetailsFromString does what the name says +func getInstanceDetailsFromString(svc *ec2.Client, instanceID string) (*types.Instance, error) { + input := &ec2.DescribeInstancesInput{ + InstanceIds: []string{instanceID}, + } + + output, err := svc.DescribeInstances(context.TODO(), input) + if err != nil { + return nil, err + } + + return &output.Reservations[0].Instances[0], nil +} diff --git a/state/config.go b/state/config.go index a74c59e..67e3bee 100644 --- a/state/config.go +++ b/state/config.go @@ -17,13 +17,15 @@ type Config struct { } type mtdconf struct { - Services []Service `yaml:"services"` + Services map[CustomUUID]Service `yaml:"services"` + } // Service contains all necessary information about a service to identify it in the cloud as well as configuring a proxy for it type Service struct { - ID CustomUUID `yaml:"id"` CloudID string `yaml:"cloud_id"` + AdminEnabled bool `yaml:"admin_enabled"` + Active bool `yaml:"active"` EntryIP netip.Addr `yaml:"entry_ip"` EntryPort uint16 `yaml:"entry_port"` ServiceIP netip.Addr `yaml:"service_ip"`