In the previous post I discussed an algorithm for calculating the square root of a number. This short post will show an implementation in Haskell.
I did one implementation using the State Monad but, to be quite honest, it was an impenetrable mess! I’m sure part of that was down to my embryonic Haskell skills, but perhaps the State Monad was simply a bad fit for this problem. The implementation below uses a fold over a list of digit pairs. The start state for the fold comes from the largest integer whose square fits into the first pair, and each subsequent step uses ‘doubling’ of the root found so far, as per the algorithm. Most of the other functions are helpers for formatting, digit pair extraction, decimal point location and so on. I think the code is fairly readable, yet I still feel it is a little too complex and would welcome any comments from more experienced Haskellers.
import Data.List.Split (splitOn)       -- from the split package
import Numeric (readFloat, showFFloat)

-- The entry point.
root :: Float -> Float
root n = f
  where
    ((f, _):_)      = readFloat (lhs ++ "." ++ rhs)
    (_, rootDigits) = rootFold n
    (lhs, rhs)      = splitAt (dpLocation n) rootDigits

-- Fold with the initial value based on the intSquareRoot function
-- and subsequent calculations based on doubling and the biggestN function.
-- Note: foldr hands the pairs to calculate starting from the right-hand end
-- of the list, so pairs are brought down last-first. That is harmless when
-- the remaining pairs are all zero, as for whole-number inputs.
rootFold :: Float -> (Integer, String)
rootFold n = foldr calculate (makeStartVal p1 p2) pairs
  where
    (p1:p2:pairs) = digitList n

    makeStartVal :: Integer -> Integer -> (Integer, String)
    makeStartVal p1 p2 = res
      where
        rt  = intSquareRoot p1
        res = (p2 + (p1 - rt * rt) * 100, show rt)

    calculate :: Integer -> (Integer, String) -> (Integer, String)
    calculate p (n, res) = next
      where
        (toAppend, remain) = biggestN (2 * read res) n
        -- bring down the next pair and accumulate the result
        next = (remain * 100 + p, res ++ show toAppend)

-- Where should the decimal point be?
-- Uses showFlt rather than show, because show switches to exponent
-- notation (e.g. "1.0e8") for large values and the split would then be wrong.
dpLocation :: Float -> Int
dpLocation n = if even len then len `div` 2 else (len + 1) `div` 2
  where
    [left, _] = splitOn "." . showFlt $ n
    len       = length left

-- helpers for formatting
formatFloatN :: Int -> Float -> String
formatFloatN numOfDecimals floatNum = showFFloat (Just numOfDecimals) floatNum ""

showFlt :: Float -> String
showFlt = formatFloatN 16

-- Takes a float and makes a list of 'paired' integers
digitList :: Float -> [Integer]
digitList n = map read $ (pairs . pad $ l) ++ (pairs . pad $ r)
  where
    [l, r] = splitOn "." . showFlt $ n
    -- chop a string into chunks of two characters
    pairs [] = []
    pairs xs = let (ys, zs) = splitAt 2 xs in ys : pairs zs
    -- left-pad with a zero so the length is even
    pad xs
      | odd (length xs) = "0" ++ xs
      | otherwise       = xs

-- eg largest number N such that 4N x N <= 161
-- and biggestN 4 161 = (3, 32)
biggestN :: Integer -> Integer -> (Integer, Integer)
biggestN = get 0
  where
    get n x y
      | (x*10 + n) * n >  y = (n - 1, y - (x*10 + n - 1) * (n - 1))
      | (x*10 + n) * n == y = (n    , y - (x*10 + n) * n)
      | otherwise           = get (n + 1) x y

-- gives the largest int whose square is <= n
intSquareRoot :: Integer -> Integer
intSquareRoot n = go 0
  where
    go i
      | i * i <= n = go (i + 1)
      | otherwise  = i - 1
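For comparison, here is a minimal sketch of the same digit-by-digit step on plain Integers, with the string handling stripped away. It is not the implementation from this post. The names RootState, digitPairs, step, rootTrace and isqrtRem are made up for the illustration, and it assumes non-negative whole-number input.

```haskell
-- State carried through the fold: (root so far, current remainder).
type RootState = (Integer, Integer)

-- Split a non-negative integer into base-100 "digit pairs",
-- most significant pair first.
digitPairs :: Integer -> [Integer]
digitPairs 0 = [0]
digitPairs n = go n []
  where
    go 0 acc = acc
    go m acc = let (q, r) = m `divMod` 100 in go q (r : acc)

-- One step: bring down the next pair, then pick the largest digit d
-- with (20 * root + d) * d <= remainder. With root = 0 this reduces to
-- "largest d whose square fits", i.e. the start state.
step :: RootState -> Integer -> RootState
step (r, rem0) pair = (10 * r + d, cur - (20 * r + d) * d)
  where
    cur = rem0 * 100 + pair
    d   = last [x | x <- [0 .. 9], (20 * r + x) * x <= cur]

-- Every intermediate state, handy for following the method by hand.
rootTrace :: Integer -> [RootState]
rootTrace = scanl step (0, 0) . digitPairs

-- Floor of the square root together with the final remainder.
isqrtRem :: Integer -> RootState
isqrtRem = last . rootTrace
```

rootTrace 625 evaluates to [(0,0),(2,2),(25,0)], and step (2, 1) 61 gives (23, 32), the same 3 and 32 as in the biggestN 4 161 comment. Scaling the input by a power of 100 yields decimal digits: fst (isqrtRem (2 * 100^7)) is 14142135.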
And some examples of using root…
λ-> map root [1, 2, 4, 33, 100, 101, 123, 144, 625, 123456.234]
[1.0,1.4142135,2.0,5.7445626,10.0,10.049875,11.090536,12.0,25.0,351.28336]
If we compose map (^2) with map root, we should get back exactly what we started with for perfect squares, and something more or less the same for the others.
λ-> map (^2) . map root $ [1, 2, 4, 33, 100, 101, 123, 144, 625, 123456.234]
[1.0,1.9999999,4.0,33.0,100.0,100.99999,122.99999,144.0,625.0,123399.99]
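One value in that run stands out: 123456.234 comes back as 123399.99, which is further off than Float rounding alone would suggest. Going only by the listing, a likely cause is that rootFold uses foldr, which hands the pairs to calculate starting from the right-hand end of the list, whereas the method brings pairs down from the left. For whole-number inputs every remaining pair is zero, so the order makes no difference, which would fit with only the last input being affected. The sketch below just records the order in which each fold visits its elements. visitOrderR, visitOrderL and samplePairs are names invented for the illustration, and the pair list is a shortened, hypothetical stand-in for the real one.

```haskell
-- Record the order in which foldr passes elements to its step function.
visitOrderR :: [Integer] -> [Integer]
visitOrderR = foldr (\p seen -> seen ++ [p]) []

-- The same for foldl.
visitOrderL :: [Integer] -> [Integer]
visitOrderL = foldl (\seen p -> seen ++ [p]) []

-- Hypothetical, shortened list of pairs still to be brought down for an
-- input shaped like 123456.234 (after the first two pairs are taken off).
samplePairs :: [Integer]
samplePairs = [56, 23, 40, 0, 0, 0]
```

visitOrderR samplePairs is [0,0,0,40,23,56], while visitOrderL samplePairs keeps the original order. If that is indeed the cause, one possible change would be foldl (flip calculate) (makeStartVal p1 p2) pairs in rootFold. That is an untested suggestion, and the outputs shown above were produced with the foldr version.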
Thanks for reading…!