/* GenPI.cpp
More challenging implementation of the Edinburgh HPC Masters test
(C) 2008 Niall Douglas
Created: 8th Mar 2008
*/

#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif
#if (defined(_MSC_VER) && ((defined(_M_IX86) && _M_IX86>=400) || defined(_M_AMD64))) \
	|| (defined(__GNUC__) && defined(__SSE2__) && (defined(__i386__) || defined(__x86_64__))) \
	|| defined(__INTEL_COMPILER)
#include <emmintrin.h>
#define HAVE_SSE2
#endif

#ifdef _MSC_VER
#define ALIGNED16 __declspec(align(16))
#else
#define ALIGNED16 __attribute__ ((aligned (16)))
#endif


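/* The general algorithm (Bailey-Borwein-Plouffe digit extraction) for the
hex digits of pi following the first n is, for each j in {1, 4, 5, 6}:
s=0;
for(k=0; k<n; k++) s+=((pow(16, n-k) % (8*k+j))/(8*k+j));   // keep only frac(s)
for(k=n; ; k++) s+=pow(16, n-k)/(8*k+j);                     // until the terms vanish
S_j=frac(s);
The wanted hex digits are then the leading hex digits of
frac(4*S_1 - 2*S_4 - S_5 - S_6).
*/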
namespace PICalcG
{
    /* Four 16-byte-aligned doubles, viewable either as scalars (v) or, when
    SSE2 is available, as two packed pairs (s). calcSSE2() writes the series
    results through s and reads them back through v.
    NOTE(review): reading a union member other than the one last written is
    undefined behaviour in standard C++; this relies on the compiler
    supporting union type punning (GCC documents it) - confirm for others. */
    struct ALIGNED16 D4
    {
        union
        {
            double v[4];        // scalar view
#ifdef HAVE_SSE2
            __m128d s[2];       // packed view: s[0] overlays v[0] (low lane) and v[1] (high lane), s[1] overlays v[2], v[3]
#endif
        };
    };
    // Table of powers of sixteen, built once up front for speed:
    // pow16s[slot] holds 16^(slot-100), i.e. exponents -100 .. digits.
    static std::vector<double> pow16s;
    void preppow16(int digits) throw()
    {
        const int lowest=-100;                  // smallest exponent stored
        pow16s.resize(digits-lowest+1);
        for(int slot=0; slot<(int)pow16s.size(); ++slot)
            pow16s[slot]=pow(16.0, slot+lowest);
    }
    // Returns 16^i from the pow16s table; i must lie in the exponent range
    // the table was last built for by preppow16() (-100 .. digits).
    inline double pow16(int i) throw()
    {
        const int slot=i+100;   // the table starts at exponent -100
        return pow16s[slot];
    }
    /* Returns 16^e mod m using left-to-right binary exponentiation carried
    out in double arithmetic. m is the modulus as an integral value (callers
    pass 8k+j). Intermediates reach m*m, so results are exact only while
    m*m < 2^53, and the quotient result/m must fit in an int.
    Returns 0 when m==1 and 1 when e==0 (for m>1). */
    inline double pow16mod(unsigned e, double m) throw()
    {
        if(1==m) return 0;
        // Highest power of two <= e (1 when e==0). A portable shift loop is
        // used instead of reading a double's exponent bits through a union,
        // which is undefined behaviour, depends on endianness, and shifted
        // by a negative amount for e==0.
        unsigned et=1;
        while(et<=(e>>1)) et<<=1;
        double result=1;
        for(; et; et>>=1)
        {
            if(e&et)
            {   // This exponent bit is set: multiply in a factor of 16
                result*=16;
                result-=(int)(result/m)*m;
            }
            if(et>1)
            {   // Lower bits remain: square
                result*=result;
                result-=(int)(result/m)*m;
            }
        }
        return result;
    }
    /* One BBP partial series scaled by 16^n:
         o = frac( sum_{k<n} (16^(n-k) mod (8k+j)) / (8k+j) )
             + sum_{k>=n} 16^(n-k) / (8k+j)
    The head is reduced modulo 1 as it accumulates; the tail is added
    unreduced, so o may exceed 1 and callers take the fractional part of the
    final combination. The tail reads pow16(0) .. pow16(-99), so preppow16()
    must have been called (with digits >= 0) beforehand. */
    template<int j> inline void series(double &o, int n) throw()
    {
        double s=0, t=0;
        int k=0;
        for(; k<n; k++)
        {
            const int r=8*k+j;
            s+=pow16mod(n-k, r)/r;
            s-=(int)s;              // keep only the fractional part
        }
        /* One is supposed to iterate the following till t(+1) and t
        converge, but we can sacrifice precision for much improved
        scaling at higher N by capping the iterations at 100. Each term is
        less than 1/16 of the previous one, so once a term falls below 1e-17
        the remainder is negligible and we stop early. The test must be on
        the latest term: the running sum t is always >= 1/(8n+j) and would
        never drop below the threshold. */
        for(; k<n+100; k++)
        {
            const double term=pow16(n-k)/(8*k+j);
            t+=term;
            if(term<1e-17) break;
        }
        o=s+t;
    }

    void calc(std::vector<unsigned char> &output, int digits)
    {
        D4 t;
        double x;
        ulonglong hex, nexthex;
        digits=(digits+7)&~7;
        // It's a byte per two hex digits
        output.resize(4+digits/2);
        preppow16(digits);
        for(int n=0; n<digits; n+=6)
        {
            series<1>(t.v[0], n);
            series<4>(t.v[1], n);
            series<5>(t.v[2], n);
            series<6>(t.v[3], n);
            x=4*t.v[0]-2*t.v[1]-t.v[2]-t.v[3];
            x-=(int)x;
            hex=((ulonglong)(x*0x100000000000000ULL /* = 16^14 */))<<8;
            /* Accuracy check - top nibble should be identical */
            if(n && (nexthex>>60)!=(hex>>60))
            {
                printf("\nOut of precision!\n");
                abort();
            }
            output[(n/2)+0]=(unsigned char)((hex>>56) & 0xff);
            output[(n/2)+1]=(unsigned char)((hex>>48) & 0xff);
            output[(n/2)+2]=(unsigned char)((hex>>40) & 0xff);
            nexthex=hex<<24;
        }
        output.resize(digits/2);
    }

#ifdef HAVE_SSE2
    /* Two-lane pow16mod() sharing one modulus: returns 16^(_e+6) mod _m in
    the low lane and 16^_e mod _m in the high lane (both zero when _m==1).
    Both exponents' bits are walked in lock-step with left-to-right binary
    exponentiation, so most multiplies and reductions are done as pairs. The
    same exactness limits as pow16mod() apply (_m*_m < 2^53). */
    inline __m128d pow16modSSE2(unsigned _e, const double _m) throw()
    {
        if(1==_m) return _mm_setzero_pd();
        unsigned e[2], et[2];
        e[0]=_e+6; e[1]=_e;
        // Highest power of two <= e[0], found with a portable shift loop
        // rather than by reading a double's exponent bits through a union
        // (undefined behaviour in C++ and wrong on big-endian targets)
        unsigned top=1;
        while(top<=(e[0]>>1)) top<<=1;
        et[0]=et[1]=top;
        // e[1] may have fewer bits; any extra leading zero bits are harmless
        // because squaring the initial 1 leaves it at 1
        if(et[1]>e[1]) et[1]>>=1;
        __m128d result=_mm_set1_pd(1), m=_mm_set1_pd(_m);
        while(et[0])
        {
            bool do0=(et[0] && e[0]>=et[0]);
            bool do1=(et[1] && e[1]>=et[1]);
            // Try to do both at once if possible
            if(do0 || do1)
            {
                if(do0 && do1)
                {   // Ah good we can do both at once
                    result=_mm_mul_pd(result, _mm_set1_pd(16));
                    result=_mm_sub_pd(result, _mm_mul_pd(_mm_cvtepi32_pd(_mm_cvttpd_epi32(_mm_div_pd(result, m))), m));
                }
                else
                {   // Only one lane: the _sd ops touch the low lane, so swap lanes for lane 1
                    if(do1)     // Swap
                        result=_mm_shuffle_pd(result, result, _MM_SHUFFLE2(0, 1));
                    result=_mm_mul_sd(result, _mm_set1_pd(16));
                    result=_mm_sub_sd(result, _mm_mul_sd(_mm_cvtsi32_sd(_mm_setzero_pd(), _mm_cvttsd_si32(_mm_div_sd(result, m))), m));
                    if(do1)     // Swap back
                        result=_mm_shuffle_pd(result, result, _MM_SHUFFLE2(0, 1));
                }
                if(do0) e[0]-=et[0];
                if(do1) e[1]-=et[1];
            }
            et[0]>>=1; et[1]>>=1;
            if(et[0])
            {
                if(et[1])
                {   // Once again, both at once
                    result=_mm_mul_pd(result, result);
                    result=_mm_sub_pd(result, _mm_mul_pd(_mm_cvtepi32_pd(_mm_cvttpd_epi32(_mm_div_pd(result, m))), m));
                }
                else
                {   // Lane 1 is finished; square lane 0 only
                    result=_mm_mul_sd(result, result);
                    result=_mm_sub_sd(result, _mm_mul_sd(_mm_cvtsi32_sd(_mm_setzero_pd(), _mm_cvttsd_si32(_mm_div_sd(result, m))), m));
                }
            }
        }
        return result;
    }
    /* Two-lane version of the scalar series(): the high lane receives the
    series for position n and the low lane the series for position n+6.
    We take advantage of _sd() functions not touching the upper double value
    to continue n+6 without affecting n+0. As in the scalar version, the
    heads are reduced modulo 1 and the tails are added unreduced. The tail
    reads pow16(0) .. pow16(-99), so preppow16() must have been called
    (with digits >= 0) beforehand. */
    template<int j> inline void series(__m128d &o, int n) throw()
    {
        __m128d s=_mm_setzero_pd(), t=_mm_setzero_pd();
        int r, k=0;
        // Heads for both positions together while k < n
        while(k<n)
        {
            r=8*k+j;
            s=_mm_add_pd(s, _mm_div_pd(pow16modSSE2(n-k, r), _mm_set1_pd(r)));
            s=_mm_sub_pd(s, _mm_cvtepi32_pd(_mm_cvttpd_epi32(s)));
            k++;
        }
        // The n+6 head has six extra terms; add them to the low lane only
        while(k<(n+6))
        {
            r=8*k+j;
            s=_mm_add_sd(s, _mm_div_sd(_mm_set_sd(pow16mod((n+6)-k, r)), _mm_set1_pd(r)));
            s=_mm_sub_sd(s, _mm_cvtsi32_sd(_mm_setzero_pd(), _mm_cvttsd_si32(s)));
            k++;
        }
        /* One is supposed to iterate the following till t(+1) and t
        converge, but we can sacrifice precision for much improved
        scaling at higher N by capping the iterations at 100. Each term is
        less than 1/16 of the previous one, so we stop once both lanes'
        latest terms fall below 1e-17. The test must be on the terms: the
        running sums are always >= 1/(8(n+6)+j) and would never drop below
        the threshold. Both lanes share the factor 16^(n-k); the low lane's
        denominator is offset by six terms. */
        const __m128d eps=_mm_set1_pd(1e-17);
        for(k=n; k<n+100; k++)
        {
            const __m128d term=_mm_div_pd(_mm_set1_pd(pow16(n-k)), _mm_set_pd(8*k+j, 8*(k+6)+j));
            t=_mm_add_pd(t, term);
            if(3==_mm_movemask_pd(_mm_cmplt_pd(term, eps))) break;
        }
        o=_mm_add_pd(s, t);
    }

    void calcSSE2(std::vector<unsigned char> &output, int digits)
    {   // Computes two digits in parallel. Lower double is n+6
        digits=(digits+7)&~7;
        // It's a byte per two hex digits
        output.resize(4+digits/2);
        preppow16(digits);
#ifdef _OPENMP
#pragma omp parallel for schedule(dynamic, 16)
#endif
        for(int n=0; n<digits; n+=12)
        {   // Calculate n+6 in parallel with n+0
            D4 t[2];
            double x[2];
            ulonglong hex, nexthex;
            series<1>(t[0].s[0], n);
            series<4>(t[0].s[1], n);
            series<5>(t[1].s[0], n);
            series<6>(t[1].s[1], n);
            x[1]=4*t[0].v[0]-2*t[0].v[2]-t[1].v[0]-t[1].v[2];
            x[1]-=(int)x[1];
            x[0]=4*t[0].v[1]-2*t[0].v[3]-t[1].v[1]-t[1].v[3];
            x[0]-=(int)x[0];
            hex=((ulonglong)(x[0]*0x100000000000000ULL /* = 16^14 */))<<8;
            output[(n/2)+0]=(unsigned char)((hex>>56) & 0xff);
            output[(n/2)+1]=(unsigned char)((hex>>48) & 0xff);
            output[(n/2)+2]=(unsigned char)((hex>>40) & 0xff);
            nexthex=hex<<24;
            hex=((ulonglong)(x[1]*0x100000000000000ULL /* = 16^14 */))<<8;
            /* Accuracy check - top nibble should be identical */
            if(n && (nexthex>>60)!=(hex>>60))
            {
                printf("\nOut of precision!\n");
                abort();
            }
            output[(n/2)+3]=(unsigned char)((hex>>56) & 0xff);
            output[(n/2)+4]=(unsigned char)((hex>>48) & 0xff);
            output[(n/2)+5]=(unsigned char)((hex>>40) & 0xff);
        }
        output.resize(digits/2);
    }
#endif

}
