# -*- coding: utf-8 -*-
"""
Created on Sat Jan  9 11:00:23 2021

@author: larsw
"""

import Particle as part
import PSO_Lib as pl
import random
import pandas as pd





    

def pso(n, pop, sample, c1=1/3, c2=1/3, c3=1/3, abort=50, verbose=True):
    """Run a particle swarm optimisation over ``sample``.

    A swarm of ``pop`` particles is created with ``part.Particle(n)``.
    Fitness is ``pl.cost_function(sample, position)``, where a lower value
    is better (every comparison below keeps the smaller cost).

    On each of ``abort`` iterations, every particle's velocity is updated
    via the ``pl`` helpers as::

        v <- c1 * v + c2 * (local_attractor - x) + c3 * (global_attractor - x)

    The particle is then moved with ``pl.add_pos_vel``, re-scored, and its
    personal best (``local_attractor``) and the swarm-wide best
    (``global_attractor``) are refreshed.

    Parameters
    ----------
    n
        Forwarded to ``part.Particle``. Presumably the problem size;
        confirm against ``Particle``.
    pop : int
        Number of particles in the swarm.
    sample
        Problem instance forwarded unchanged to ``pl.cost_function``.
    c1, c2, c3 : float
        Weights applied via ``pl.multipl_coeff_vel`` to the current velocity,
        the pull towards the particle's best, and the pull towards the
        swarm's best, respectively.
    abort : int
        Number of main-loop iterations to run.
    verbose : bool, keyword
        When true (the default, matching the historical behaviour), print
        the previous global best whenever it improves and the current
        global best after every iteration.

    Returns
    -------
    list
        ``test_array``. It is never filled in this function, so the
        result is always an empty list.
        NOTE(review): presumably meant to collect results (e.g. the best
        cost per iteration). Confirm with callers before relying on it.
    """
    import copy  # local import keeps the module's import block untouched

    test_array = list()
    global_attractor = list()
    global_attractor_cost = float('inf')

    ##### Initialize #####
    population = []
    for _ in range(pop):
        particle = part.Particle(n)
        # Score each new particle exactly once. The previous version
        # re-scored the entire population after every append, giving
        # O(pop^2) cost_function calls for the same result.
        # Snapshots via copy.copy guard the stored best positions in case
        # a pl helper ever mutates a position in place (not visible here).
        particle.local_attractor = copy.copy(particle.position)
        particle.local_attractor_cost = pl.cost_function(
            sample, particle.local_attractor)

        if particle.local_attractor_cost < global_attractor_cost:
            global_attractor = copy.copy(particle.position)
            global_attractor_cost = particle.local_attractor_cost

        population.append(particle)

    ##### Main loop #####
    for _ in range(abort):
        for p in population:
            # Terms are built in the same order as before, in case the pl
            # helpers consume randomness.
            inertia = pl.multipl_coeff_vel(c1, p.velocity)
            cognitive = pl.multipl_coeff_vel(
                c2, pl.sub_pos(p.local_attractor, p.position))
            social = pl.multipl_coeff_vel(
                c3, pl.sub_pos(global_attractor, p.position))
            p.velocity = pl.add_vel(inertia, pl.add_vel(cognitive, social))

            p.position = pl.add_pos_vel(p.position, p.velocity)
            p.cost = pl.cost_function(sample, p.position)

            # Personal best.
            if p.cost < p.local_attractor_cost:
                p.local_attractor = copy.copy(p.position)
                p.local_attractor_cost = p.cost

            # Swarm-wide best.
            if p.cost < global_attractor_cost:
                if verbose:
                    # The previous best, reported just before it is replaced.
                    print(global_attractor)
                    print(global_attractor_cost)
                global_attractor = copy.copy(p.position)
                global_attractor_cost = p.cost

        if verbose:
            print(global_attractor)
            print(global_attractor_cost)
            print("\n")

    return test_array