{-# LANGUAGE ParallelListComp #-}
module Math.Polynomial.Bernstein
    ( bernstein
    , evalBernstein
    , bernsteinFit
    , evalBernsteinSeries
    , deCasteljau
    , splitBernsteinSeries
    ) where

import Math.Polynomial
import Data.List

-- |The Bernstein basis polynomials, one basis per degree.  Element @n@ of
-- the outer list (counting from @0@) is the degree-@n@ basis: the @n+1@
-- polynomials @C(n,v) * x^v * (1-x)^(n-v)@ for @v = 0..n@, listed in order
-- of increasing @v@.  Each basis spans the polynomials of degree @n@ or
-- lower, its members sum to @1@, and member @v@ has a root of multiplicity
-- @v@ at @0@ and of multiplicity @n-v@ at @1@.
bernstein :: [[Poly Integer]]
bernstein = zipWith3 basisRow pascal xPowerRows oneMinusXPowerRows
    where
        -- Assemble one basis from its three ingredients: the binomial
        -- coefficients of row n, the ascending powers x^0..x^n, and the
        -- powers (1-x)^0..(1-x)^n reversed so that x^v meets (1-x)^(n-v).
        basisRow :: [Integer] -> [Poly Integer] -> [Poly Integer] -> [Poly Integer]
        basisRow coeffs xPows oneMinusXPows =
            zipWith3 (\c p q -> multPoly (scalePoly c p) q)
                     coeffs xPows (reverse oneMinusXPows)

        -- [[1], [1, x], [1, x, x^2], ...]; x^k is written big-endian as a
        -- leading 1 followed by k zeros.
        xPowerRows :: [[Poly Integer]]
        xPowerRows = tail (inits (map (poly BE . (1 :)) (inits (repeat 0))))

        -- [[1], [1, 1-x], [1, 1-x, (1-x)^2], ...]
        oneMinusXPowerRows :: [[Poly Integer]]
        oneMinusXPowerRows =
            tail (inits (iterate (multPoly (poly LE [1, -1])) one))

        -- Pascal's triangle: row n holds C(n,0) .. C(n,n).
        pascal :: [[Integer]]
        pascal = iterate nextRow [1]

        nextRow :: [Integer] -> [Integer]
        nextRow row = 1 : zipWith (+) row (tail row) ++ [1]

-- |@evalBernstein n v t@ evaluates the @v@th Bernstein polynomial of degree
-- @n@, that is @C(n,v) * t^v * (1-t)^(n-v)@, at the point @t@.
--
-- Indices outside the range @0 <= v <= n@ (which includes every negative
-- @n@) denote no basis polynomial and evaluate to @0@.
evalBernstein :: (Integral a, Num b) => a -> a -> b -> b
evalBernstein n v t
    -- The v < 0 check is required: without it a negative v would reach
    -- t ^ v below, and (^) raises "Negative exponent".
    | n < 0 || v < 0 || v > n = 0
    | otherwise               = fromInteger nCv * t ^ v * (1 - t) ^ (n - v)
    where
        n', v' :: Integer
        n' = toInteger n
        v' = toInteger v

        -- C(n,v) = n! / (v! (n-v)!), computed as the falling factorial
        -- n (n-1) .. (n-v+1) divided by v!.  The division is exact in Integer.
        nCv :: Integer
        nCv = product [n' - v' + 1 .. n'] `div` product [1 .. v']

-- |@bernsteinFit n f@ approximates the function @f@ by a linear combination
-- of the degree-@n@ Bernstein basis polynomials.  The result is the list of
-- @n+1@ coefficients @f (v/n)@ for @v = 0..n@.  As @n@ grows, the resulting
-- series converges to @f@ uniformly, but slowly, on the interval @[0,1]@.
--
-- For a negative @n@ the list is empty.  For @n = 0@ the single sample point
-- is computed as @0/0@, so callers normally want @n >= 1@.
bernsteinFit :: (Fractional b, Integral a) => a -> (b -> b) -> [b]
bernsteinFit n f = map sample [0 .. n]
    where
        -- Evaluate f at the v'th of the n+1 evenly spaced nodes in [0,1].
        sample v = f (fromIntegral v / fromIntegral n)

-- |Evaluate, at the point @t@, a polynomial given by its @n+1@ coefficients
-- with respect to the degree-@n@ Bernstein basis.  Informally:
--
-- > evalBernsteinSeries cs = sum (zipWith scalePoly cs (bernstein !! (length cs - 1)))
--
-- The value is the final entry of the de Casteljau tableau.  An empty
-- coefficient list evaluates to @0@.
evalBernsteinSeries :: Num a => [a] -> a -> a
evalBernsteinSeries cs t =
    case deCasteljau cs t of
        []   -> 0
        -- The tableau for a non-empty cs ends with a single-element row,
        -- so head is safe here.
        rows -> head (last rows)

-- |de Casteljau's algorithm, returning the entire tableau.  The first row is
-- the input coefficient list.  Each subsequent row is formed by linearly
-- interpolating every adjacent pair of the previous row at parameter @t@, so
-- it is one element shorter.  The tableau ends at the first empty row, which
-- is itself excluded.  The tableau is used both to evaluate and to split
-- polynomials in Bernstein form.
deCasteljau :: Num a => [a] -> a -> [[a]]
deCasteljau cs t = takeWhile (not . null) (iterate step cs)
    where
        -- One reduction step: blend each element with its right neighbour.
        step row = zipWith lerp row (drop 1 row)

        lerp x0 x1 = (1 - t) * x0 + t * x1

-- |Given a polynomial in Bernstein form (a list of coefficients for one of
-- the bases in 'bernstein', such as 'bernsteinFit' returns) and a parameter
-- value @t@, split the polynomial into two pieces.  The pieces are the
-- original restricted to @[0,t]@ and to @[t,1]@, each reparameterised onto
-- @[0,1]@.
--
-- A typical use is subdividing a Bezier curve, i.e. inserting a new knot at
-- @t@.
splitBernsteinSeries :: Num a => [a] -> a -> ([a], [a])
splitBernsteinSeries cs t = (leftPiece, rightPiece)
    where
        tableau    = deCasteljau cs t
        -- The left piece's control points run down the tableau's left edge.
        leftPiece  = map head tableau
        -- The right piece's control points run up the tableau's right edge.
        rightPiece = reverse (map last tableau)