-- SICP Chapter #02 Examples in Lua
-- Utility functions

--- Render a (possibly nested) table as a brace-delimited string.
-- Entries with numeric keys are shown as bare values; other entries are
-- shown as `key=value`. Nested tables are rendered recursively.
-- Keys and scalar values go through `tostring`, so booleans (and other
-- non-string, non-number values) no longer raise a concatenation error.
-- Entries are visited with `pairs`, so their order follows the
-- interpreter's traversal order.
-- @param T any value; non-tables are converted with `tostring`
-- @return string
function table_to_string(T)
  if type(T) ~= 'table' then return tostring(T) end
  local parts = {}
  for key, value in pairs(T) do
    local rendered = table_to_string(value)
    if type(key) == 'number' then
      parts[#parts + 1] = rendered
    else
      parts[#parts + 1] = tostring(key) .. "=" .. rendered
    end
  end
  return "{" .. table.concat(parts, ",") .. "}"
end

--- Print one or more values, rendering tables with `table_to_string`.
-- With one argument this prints exactly what it did before. Further
-- arguments are now printed too (tab-separated, like `print`); previously
-- everything after the first argument was silently dropped.
-- A call with no arguments prints "nil", as before.
function printx(...)
  local count = math.max(select('#', ...), 1)
  local rendered = {}
  for i = 1, count do
    local value = (select(i, ...))
    if type(value) == 'table' then
      rendered[i] = table_to_string(value)
    else
      rendered[i] = value
    end
  end
  print((table.unpack or unpack)(rendered, 1, count))
end
--- Shallow copy of a table: a new table holding the same key/value pairs.
-- Nested tables are shared with the source, not duplicated.
function table_copy(T)
  local copy = {}
  for k, v in pairs(T) do copy[k] = v end
  return copy
end
--- New array holding T[first] .. T[last] (inclusive).
-- `last` defaults to #T, so slice(T, 2) plays the role of `cdr`.
function slice(T, first, last)
  local stop = last
  if stop == nil then
    stop = #T
  end
  local out = {}
  for index = first, stop do
    out[#out + 1] = T[index]
  end
  return out
end
-- Functions defined in previous chapters

--- Greatest common divisor by Euclid's algorithm.
-- Uses the `%` operator. `math.mod` only exists in Lua 5.0 (and in 5.1
-- behind a compatibility flag), so the original call fails on modern
-- interpreters.
function gcd(a, b)
  if b == 0 then
    return a
  end
  return gcd(b, a % b)
end
--- Tree-recursive Fibonacci: fib(0) = 0, fib(1) = 1,
-- fib(n) = fib(n - 1) + fib(n - 2).
function fib(n)
  if n == 0 then return 0 end
  if n == 1 then return 1 end
  local older = fib(n - 2)
  local newer = fib(n - 1)
  return older + newer
end
--- Identity function: returns its first argument unchanged.
function identity(x)
  return x
end
-- 2 Building Abstractions with Data

--- a*x + b*y using the primitive operators directly.
function linear_combination(a, b, x, y)
  local ax, by = a * x, b * y
  return ax + by
end

--- Multiplication as a named function (stand-in for SICP's generic `mul`).
function mul(a, b)
  return a * b
end

--- The same combination expressed through `mul`; this definition
-- replaces the one above.
function linear_combination(a, b, x, y)
  local ax = mul(a, x)
  local by = mul(b, y)
  return ax + by
end
-- 2.1.1 Introduction to Data Abstraction - Example: Arithmetic Operations for Rational Numbers
-- Literal Translation --

--- Rational number stored as the array {numerator, denominator}.
function make_rat(n, d) return { n, d } end
--- Numerator: first slot of a rational.
function numer(T) return T[1] end
--- Denominator: second slot of a rational.
function denom(T) return T[2] end

--- n1/d1 + n2/d2 = (n1*d2 + n2*d1) / (d1*d2)
function add_rat(x, y)
  local nx, dx = numer(x), denom(x)
  local ny, dy = numer(y), denom(y)
  return make_rat(nx * dy + ny * dx, dx * dy)
end

--- n1/d1 - n2/d2 = (n1*d2 - n2*d1) / (d1*d2)
function sub_rat(x, y)
  local nx, dx = numer(x), denom(x)
  local ny, dy = numer(y), denom(y)
  return make_rat(nx * dy - ny * dx, dx * dy)
end

--- (n1*n2) / (d1*d2)
function mul_rat(x, y)
  local top = numer(x) * numer(y)
  local bottom = denom(x) * denom(y)
  return make_rat(top, bottom)
end

--- (n1*d2) / (d1*n2)
function div_rat(x, y)
  local top = numer(x) * denom(y)
  local bottom = denom(x) * numer(y)
  return make_rat(top, bottom)
end

--- Cross-multiplication equality: n1*d2 == n2*d1.
function equal_rat(x, y)
  local lhs = numer(x) * denom(y)
  local rhs = numer(y) * denom(x)
  return lhs == rhs
end
--- Pair constructor: a two-slot array {x, y}.
function cons(x, y)
  return { x, y }
end

--- First element of a pair.
function car(T)
  return T[1]
end

--- Second element of a pair.
function cdr(T)
  return T[2]
end
-- Demo: pairs built with the table-based cons above.
x = cons(1, 2)
-- prints 1
printx (car(x))
-- prints 2
printx (cdr(x))
x = cons(1, 2)
y = cons(3, 4)
-- z is a pair whose two elements are themselves pairs.
z = cons(x, y)
-- prints 1
printx (car(car(z)))
-- prints 3
printx (car(cdr(z)))
printx (z)
--footnote -- alternative definitions
-- Alias the rational constructor/selectors directly to the pair
-- primitives instead of wrapping them in new functions.
make_rat = cons
numer = car
denom = cdr
x = {1, 2}
y = {3, 4}
-- prints 1
printx (numer(x))
-- prints 2
printx (denom(x))
--- compose(f, g) returns the function x -> f(g(x)).
function compose(f, g)
  local function composed(x)
    return f(g(x))
  end
  return composed
end
--- Print a rational as "n/d", reading it through the current numer/denom.
function print_rat(x)
  local text = string.format("%d/%d", numer(x), denom(x))
  printx(text)
end
-- Demo: literal (array) rationals, not yet reduced to lowest terms.
one_half = make_rat(1,2)
-- prints 1/2
print_rat(one_half)
one_third = make_rat(1, 3)
-- prints 5/6
print_rat(add_rat(one_half, one_third))
-- prints 1/6
print_rat(mul_rat(one_half, one_third))
-- prints 6/9: no reduction is performed by this make_rat
print_rat(add_rat(one_third, one_third))
-- reducing to lowest terms in constructor

--- Rational {n, d} divided through by gcd(n, d).
function make_rat(n, d)
  local divisor = gcd(n, d)
  local reduced_n, reduced_d = n / divisor, d / divisor
  return { reduced_n, reduced_d }
end

--- Same addition rule as before, now going through the reducing make_rat.
function add_rat(x, y)
  local nx, dx = numer(x), denom(x)
  local ny, dy = numer(y), denom(y)
  return make_rat(nx * dy + ny * dx, dx * dy)
end
-- The sum 6/9 now passes through the reducing make_rat, so this prints 2/3.
print_rat(add_rat(one_third, one_third))
-- end Literal Translation --
-- Record Translation --

--- Rational number as a record with fields `numer` and `denom`.
function make_rat(n, d)
  return { numer = n, denom = d }
end

--- (n1*d2 + n2*d1) / (d1*d2)
function add_rat(x, y)
  local top = x.numer * y.denom + y.numer * x.denom
  return make_rat(top, x.denom * y.denom)
end

--- (n1*d2 - n2*d1) / (d1*d2)
function sub_rat(x, y)
  local top = x.numer * y.denom - y.numer * x.denom
  return make_rat(top, x.denom * y.denom)
end

--- (n1*n2) / (d1*d2)
function mul_rat(x, y)
  local top, bottom = x.numer * y.numer, x.denom * y.denom
  return make_rat(top, bottom)
end

--- (n1*d2) / (d1*n2)
function div_rat(x, y)
  local top, bottom = x.numer * y.denom, x.denom * y.numer
  return make_rat(top, bottom)
end

--- Cross-multiplication equality.
function equal_rat(x, y)
  local lhs, rhs = x.numer * y.denom, y.numer * x.denom
  return lhs == rhs
end

--- Print the record as "n/d".
function print_rat(x)
  local text = string.format("%d/%d", x.numer, x.denom)
  printx(text)
end
-- Demo: record rationals, not yet reduced to lowest terms.
one_half = make_rat(1,2)
-- prints 1/2
print_rat(one_half)
one_third = make_rat(1, 3)
-- prints 5/6
print_rat(add_rat(one_half, one_third))
-- prints 1/6
print_rat(mul_rat(one_half, one_third))
-- prints 6/9: this make_rat does not reduce
print_rat(add_rat(one_third, one_third))
-- reducing to lowest terms in constructor

--- Record rational with both parts divided through by gcd(n, d).
function make_rat(n, d)
  local divisor = gcd(n, d)
  return { numer = n / divisor, denom = d / divisor }
end

--- Addition routed through the reducing make_rat.
function add_rat(x, y)
  local cross = (x.numer * y.denom) + (y.numer * x.denom)
  local common = x.denom * y.denom
  return make_rat(cross, common)
end
-- With the reducing make_rat the sum prints as 2/3.
print_rat(add_rat(one_third, one_third))
-- end Record Translation --
-- Object Translation --

--- Rational number object (not reduced). Fields `numer`/`denom` hold the
-- parts; each arithmetic method combines `self` with another Rational `y`
-- and returns a new Rational.
function Rational(n, d)
  local r = { numer = n, denom = d }

  function r:add_rat(y)
    local top = self.numer * y.denom + y.numer * self.denom
    return Rational(top, self.denom * y.denom)
  end

  function r:sub_rat(y)
    local top = self.numer * y.denom - y.numer * self.denom
    return Rational(top, self.denom * y.denom)
  end

  function r:mul_rat(y)
    return Rational(self.numer * y.numer, self.denom * y.denom)
  end

  function r:div_rat(y)
    return Rational(self.numer * y.denom, self.denom * y.numer)
  end

  function r:equal_rat(y)
    return self.numer * y.denom == y.numer * self.denom
  end

  function r:print_rat()
    printx(string.format("%d/%d", self.numer, self.denom))
  end

  return r
end
-- Demo: Rational objects (no reduction).
one_half = Rational(1, 2)
-- prints 1/2
one_half:print_rat()
one_third = Rational(1, 3)
-- prints 5/6
one_half:add_rat(one_third):print_rat()
-- prints 1/6
one_half:mul_rat(one_third):print_rat()
-- prints 6/9
one_third:add_rat(one_third):print_rat()
-- reducing to lowest terms in constructor

--- Rational object whose fields are stored already divided by gcd(n, d).
-- Methods mirror the unreduced version and build results via Rational,
-- so every result is reduced too.
function Rational(n, d)
  local divisor = gcd(n, d)
  local obj = { numer = n / divisor, denom = d / divisor }

  function obj:add_rat(y)
    local top = self.numer * y.denom + y.numer * self.denom
    return Rational(top, self.denom * y.denom)
  end

  function obj:sub_rat(y)
    local top = self.numer * y.denom - y.numer * self.denom
    return Rational(top, self.denom * y.denom)
  end

  function obj:mul_rat(y)
    return Rational(self.numer * y.numer, self.denom * y.denom)
  end

  function obj:div_rat(y)
    return Rational(self.numer * y.denom, self.denom * y.numer)
  end

  function obj:equal_rat(y)
    return self.numer * y.denom == y.numer * self.denom
  end

  function obj:print_rat()
    printx(string.format("%d/%d", self.numer, self.denom))
  end

  return obj
end
one_third = Rational(1, 3)
-- prints 1/3
one_third:print_rat()
-- prints 2/3 (6/9 reduced by the constructor)
one_third:add_rat(one_third):print_rat()
-- end Object Translation --
-- Exercise 2.1

--- Construct a rational {numer, denom} with a normalized sign: the
-- denominator is always non-negative and the sign is carried by the
-- numerator. Both negative -> both positive; only d negative -> flip both.
-- The original swapped numerator and denominator in both branches and
-- decided the sign from `n` alone.
function make_rat(n, d)
  if d < 0 then
    return {-n, -d}
  end
  return {n, d}
end
-- 2.1.2 Introduction to Data Abstraction - Abstraction barriers
-- Record Translation --
-- reducing to lowest terms in selectors

--- Store the rational unreduced; reduction happens in the selectors.
function make_rat(n, d)
  return { numer = n, denom = d }
end

--- Numerator in lowest terms (stored numerator / gcd).
function numer(x)
  return x.numer / gcd(x.numer, x.denom)
end

--- Denominator in lowest terms (stored denominator / gcd).
function denom(x)
  return x.denom / gcd(x.numer, x.denom)
end
-- end Record Translation --
-- Object Translation --
-- reducing to lowest terms in selectors

--- Rational object that keeps n and d unreduced in its closure and
-- reduces on access: numer()/denom() divide by gcd(n, d) on every call.
-- Arithmetic methods read the other operand through its numer()/denom().
function Rational(n, d)
  local rat = {}

  function rat:numer()
    return n / gcd(n, d)
  end

  function rat:denom()
    return d / gcd(n, d)
  end

  function rat:add_rat(y)
    local top = (self:numer() * y:denom()) + (y:numer() * self:denom())
    return Rational(top, self:denom() * y:denom())
  end

  function rat:sub_rat(y)
    local top = (self:numer() * y:denom()) - (y:numer() * self:denom())
    return Rational(top, self:denom() * y:denom())
  end

  function rat:mul_rat(y)
    return Rational(self:numer() * y:numer(), self:denom() * y:denom())
  end

  function rat:div_rat(y)
    return Rational(self:numer() * y:denom(), self:denom() * y:numer())
  end

  function rat:equal_rat(y)
    return self:numer() * y:denom() == y:numer() * self:denom()
  end

  function rat:print_rat()
    printx(string.format("%d/%d", self:numer(), self:denom()))
  end

  return rat
end
-- end Object Translation --
-- Exercise 2.2

--- Point record with fields x and y.
function make_point(x, y) return {x = x, y = y} end

--- Segment record holding its two end points.
function make_segment(start_segment, end_segment)
  return {start_segment = start_segment, end_segment = end_segment}
end

--- Point halfway between a segment's two end points.
function midpoint_segment(segment)
  local s = segment.start_segment
  local e = segment.end_segment
  return make_point((s.x + e.x) / 2, (s.y + e.y) / 2)
end

--- Print a point as "(x,y)".
-- Uses %.14g instead of %d because midpoints are often fractional
-- (e.g. 1.5). %d truncates such values in Lua 5.1/5.2 and raises
-- "number has no integer representation" in 5.3+. Integral coordinates
-- still print without a decimal part.
function print_point(p)
  printx (string.format("(%.14g,%.14g)", p.x, p.y))
end
-- Exercise 2.3

--- x * x
function square(x) return x * x end

--- Euclidean distance between a segment's two end points.
function length_segment(segment)
  local s = segment.start_segment
  local e = segment.end_segment
  return math.sqrt(square(e.x - s.x) + square(e.y - s.y))
end

-- Two rectangle representations. Neither carries an explicit type tag;
-- the selectors below tell them apart by which fields are present.

-- Representation 1: an anchor point plus side lengths along x and y.
function make_rectangle_axy(anchor, xlen, ylen)
  return {anchor=anchor, xlen=xlen, ylen=ylen}
end

-- Representation 2: two points. NOTE(review): assumed to be opposite
-- corners of an axis-aligned rectangle, given as {x=..., y=...} points
-- (as built by make_point). Confirm against callers.
function make_rectangle_seg(start_segment, end_segment)
  return {start_segment=start_segment, end_segment=end_segment}
end

-- 'length_rectangle' and 'width_rectangle' act as an abstraction barrier for higher-level
-- procedures because 'rectangle' implementation details are buried here, and should the
-- implementation change, only these procedures will need to be altered to support the change

--- Extent of the rectangle along x. The original returned placeholder
-- zeros, so every area and perimeter came out as 0.
-- Raises an error for a table matching neither representation (the
-- original silently returned nil).
function length_rectangle(rect)
  if rect.anchor ~= nil then
    return rect.xlen
  elseif rect.start_segment ~= nil then
    return math.abs(rect.end_segment.x - rect.start_segment.x)
  end
  error("Unknown rectangle representation -- LENGTH_RECTANGLE")
end

--- Extent of the rectangle along y (counterpart of length_rectangle).
function width_rectangle(rect)
  if rect.anchor ~= nil then
    return rect.ylen
  elseif rect.start_segment ~= nil then
    return math.abs(rect.end_segment.y - rect.start_segment.y)
  end
  error("Unknown rectangle representation -- WIDTH_RECTANGLE")
end

-- High-level procedures are quarantined from representation / implementation details
function area_rectangle(rect)
  return length_rectangle(rect) * width_rectangle(rect)
end
function perimeter_rectangle(rect)
  return length_rectangle(rect) * 2 + width_rectangle(rect) * 2
end
-- 2.1.3 Introduction to Data Abstraction - What is meant by data?

--- Procedural pair: cons returns a dispatch function. Calling it with 1
-- yields x, with 2 yields y; any other argument raises an error.
function cons(x, y)
  return function(m)
    if m == 1 then return x end
    if m == 2 then return y end
    error("Argument not 1 or 2 -- CONS " .. tostring(m))
  end
end

--- First element: ask the dispatch function for slot 1.
function car(z)
  return z(1)
end

--- Second element: ask the dispatch function for slot 2.
function cdr(z)
  return z(2)
end
-- Exercise 2.4

--- Pair as a closure that hands both elements to a selector m.
function cons(x, y)
  local function pair(m)
    return m(x, y)
  end
  return pair
end

--- First element: select p from (p, q).
function car(z)
  local function first(p, _) return p end
  return z(first)
end

--- Second element: select q from (p, q).
function cdr(z)
  local function second(_, q) return q end
  return z(second)
end
-- Exercise 2.5

--- Represent the pair (x, y) of non-negative integers as 2^x * 3^y.
-- The original computed math.pow(2, x * math.pow(3, y)) = 2^(x * 3^y),
-- which is not that encoding. `^` also replaces math.pow, which Lua 5.3
-- removed.
function cons(x, y)
  return 2 ^ x * 3 ^ y
end

--- Number of times `factor` divides z exactly. z must be positive;
-- 0 would loop forever because every factor divides it.
local function count_factors(z, factor)
  if z <= 0 then
    error("Expected a positive encoded pair -- " .. tostring(z))
  end
  local count = 0
  while z % factor == 0 do
    z = z / factor
    count = count + 1
  end
  return count
end

--- x of an encoded pair: the exponent of 2 in z. The original recursed on
-- (z/2)+1 and returned 0 for any odd value, so it never recovered x.
function car(z)
  return count_factors(z, 2)
end

--- y of an encoded pair: the exponent of 3 in z.
function cdr(z)
  return count_factors(z, 3)
end
-- Exercise 2.6

--- Church numeral zero: applies f zero times (returns x untouched).
function zero(f)
  return function(x) return x end
end

--- Church successor: add1(n) applies f once more than n does.
function add1(n)
  local function successor(f)
    local function step(x)
      local inner = n(f)
      return f(inner(x))
    end
    return step
  end
  return successor
end
-- 2.1.4 Introduction to Data Abstraction - Extended Exercise: Interval Arithmetic
-- Record Translation --

--- Interval record [a, b].
function make_interval(a, b)
  return { lower_bound = a, upper_bound = b }
end

--- Sum of intervals: add the bounds pairwise.
function add_interval(x, y)
  local lo = x.lower_bound + y.lower_bound
  local hi = x.upper_bound + y.upper_bound
  return make_interval(lo, hi)
end

--- Product of intervals: min and max over the four bound products.
function mul_interval(x, y)
  local xl, xu = x.lower_bound, x.upper_bound
  local yl, yu = y.lower_bound, y.upper_bound
  local p1, p2, p3, p4 = xl * yl, xl * yu, xu * yl, xu * yu
  local lo = math.min(math.min(p1, p2), math.min(p3, p4))
  local hi = math.max(math.max(p1, p2), math.max(p3, p4))
  return make_interval(lo, hi)
end

--- Quotient: multiply by the reciprocal interval [1/upper, 1/lower].
function div_interval(x, y)
  local reciprocal = make_interval(1 / y.upper_bound, 1 / y.lower_bound)
  return mul_interval(x, reciprocal)
end

--- Interval [c - w, c + w].
function make_center_width(c, w)
  return make_interval(c - w, c + w)
end

--- Midpoint of the interval.
function center(interval)
  local sum = interval.lower_bound + interval.upper_bound
  return sum / 2
end

--- Half of the interval's span.
function width(interval)
  local span = interval.upper_bound - interval.lower_bound
  return span / 2
end

-- parallel resistors

--- r1*r2 / (r1 + r2)
function par1(r1, r2)
  local product = mul_interval(r1, r2)
  local total = add_interval(r1, r2)
  return div_interval(product, total)
end

--- 1 / (1/r1 + 1/r2)
function par2(r1, r2)
  local one = make_interval(1, 1)
  local inv1 = div_interval(one, r1)
  local inv2 = div_interval(one, r2)
  return div_interval(one, add_interval(inv1, inv2))
end
-- end Record Translation --
-- Object Translation --

--- Interval object with lower_bound/upper_bound fields and methods that
-- combine `self` with another interval `y`, returning new Intervals.
function Interval(a, b)
  local iv = { lower_bound = a, upper_bound = b }

  function iv:add_interval(y)
    return Interval(self.lower_bound + y.lower_bound, self.upper_bound + y.upper_bound)
  end

  function iv:mul_interval(y)
    local sl, su = self.lower_bound, self.upper_bound
    local yl, yu = y.lower_bound, y.upper_bound
    local p1, p2, p3, p4 = sl * yl, sl * yu, su * yl, su * yu
    return Interval(
      math.min(math.min(p1, p2), math.min(p3, p4)),
      math.max(math.max(p1, p2), math.max(p3, p4)))
  end

  function iv:div_interval(y)
    local reciprocal = Interval(1 / y.upper_bound, 1 / y.lower_bound)
    return self:mul_interval(reciprocal)
  end

  -- `self` is unused; kept as a method for the original call shape.
  function iv:make_center_width(c, w)
    return Interval(c - w, c + w)
  end

  function iv:center()
    return (self.lower_bound + self.upper_bound) / 2
  end

  function iv:width()
    return (self.upper_bound - self.lower_bound) / 2
  end

  return iv
end
interval = Interval(10, 20)
-- parallel resistors

--- r1*r2 / (r1 + r2) using interval methods.
function par1(r1, r2)
  local product = r1:mul_interval(r2)
  local total = r1:add_interval(r2)
  return product:div_interval(total)
end

--- 1 / (1/r1 + 1/r2) using interval methods.
function par2(r1, r2)
  local one = Interval(1, 1)
  local inv1 = one:div_interval(r1)
  local inv2 = one:div_interval(r2)
  return one:div_interval(inv1:add_interval(inv2))
end
-- end Object Translation --
-- Exercise 2.8

--- Interval difference x - y. The smallest possible result is x's lower
-- bound minus y's upper bound; the largest is x's upper bound minus y's
-- lower bound. The original subtracted matching bounds, which can give an
-- inverted or too-narrow interval (e.g. [5,10] - [15,25] gave [-10,-15]).
function sub_interval(x, y)
  return make_interval(x.lower_bound - y.upper_bound, x.upper_bound - y.lower_bound)
end
-- Exercise 2.9
i = make_interval(5, 10)
j = make_interval(15, 25)
-- Each line prints two numbers side by side: the width of the combined
-- interval, then the sum of the operands' widths. `print` is used so that
-- both values are always shown.
-- width of the sum (or difference) of two intervals *is* a function only of the widths of
-- the intervals being added (or subtracted)
print(width(add_interval(i, j)), width(i) + width(j))
print(width(sub_interval(i, j)), width(i) + width(j))
-- width of the product (or quotient) of two intervals *is not* a function only of the widths
-- of the intervals being multiplied (or divided)
print(width(mul_interval(i, j)), width(i) + width(j))
print(width(div_interval(i, j)), width(i) + width(j))
-- Exercise 2.10

--- True when the interval contains zero (lower <= 0 <= upper). Dividing
-- by such an interval is meaningless because 1/y is unbounded inside it.
-- The original only caught intervals with an endpoint exactly at 0 and
-- let intervals like [-1, 1] through.
function is_zero_interval(i)
  return i.lower_bound <= 0 and i.upper_bound >= 0
end

--- div_interval that raises an error when the divisor spans zero.
function div_interval_zero_check(x, y)
  if is_zero_interval(y) then
    error("Zero interval divisor")
  end
  return div_interval(x, y)
end
-- Exercise 2.12

--- Interval centred on c whose half-width is p percent of c.
function make_center_percent(c, p)
  local half_width = p * c / 100
  return make_center_width(c, half_width)
end

--- Half-width as a percentage of the centre.
function percent(i)
  local ratio = width(i) / center(i)
  return ratio * 100
end
-- 2.2.1 Hierarchical Data and the Closure Property - Representing Sequences
one_through_four = {1, 2, 3, 4}
printx (one_through_four)
-- (car one-through-four)
printx (one_through_four[1])
-- (cdr one-through-four)
printx (slice(one_through_four, 2))
-- (car (cdr one-through-four))
printx (one_through_four[2])
-- (cons 10 one-through-four): build a copy with 10 in front and print it.
-- table.insert mutates its argument, so it works on a copy. The original
-- appended 10 to a copy that was then discarded, so the result never
-- appeared.
do
  local with_ten = table_copy(one_through_four)
  table.insert(with_ten, 1, 10)
  printx (with_ten)
end
-- The original list is unchanged.
printx (one_through_four)
--- Element n of a list. Lua lists are 1-based, so list_ref(items, 1) is
-- the first element (SICP's list-ref counts from 0).
function list_ref(items, n)
  local element = items[n]
  return element
end
squares = {1, 4, 9, 16, 25}
-- prints 16 (SICP's (list-ref squares 3))
printx (list_ref(squares, 4)) -- note: Lua index begins at 1
--- Number of elements in a sequence (the `#` border).
function length(items)
  local count = #items
  return count
end
odds = {1, 3, 5, 7}
-- prints 4
printx (#odds)
--- New sequence holding T1's elements followed by T2's.
-- T1 is shallow-copied with table_copy. T2 is walked by index (1..#T2)
-- because `pairs` makes no ordering guarantee, even for integer keys.
function append(T1, T2)
  local rt = table_copy(T1)
  for i = 1, #T2 do
    rt[#rt + 1] = T2[i]
  end
  return rt
end
-- squares followed by odds, then odds followed by squares
printx(append(squares, odds))
printx(append(odds, squares))
-- Mapping over lists

--- Multiply every value of T by factor, keeping each value's key.
function scale_list(factor, T)
  local scaled = {}
  for k, v in pairs(T) do
    scaled[k] = v * factor
  end
  return scaled
end
-- prints {10,20,30,40,50}
printx (scale_list(10, {1, 2, 3, 4, 5}))
-- uncurried version of map

--- Apply proc to every value of T, keeping each value's key.
function map(proc, T)
  local mapped = {}
  for k, v in pairs(T) do
    mapped[k] = proc(v)
  end
  return mapped
end
-- absolute values, then squares
printx (map(math.abs, {-10, 2.5, -11.6, 17}))
printx (map(function(x) return x * x end, {1, 2, 3, 4}))
--- scale_list expressed with the uncurried map.
function scale_list(factor, items)
  local function times_factor(x)
    return x * factor
  end
  return map(times_factor, items)
end
-- curried version map

--- map(proc)(T): applies proc to every value of T, keeping keys.
function map(proc)
  return function(T)
    local mapped = {}
    for k, v in pairs(T) do
      mapped[k] = proc(v)
    end
    return mapped
  end
end
-- same results as the uncurried demo, using map(proc)(list)
printx (map (math.abs) ({-10, 2.5, -11.6, 17}))
printx (map (function(x) return x * x end) ({1, 2, 3, 4}))
--- scale_list expressed with the curried map.
function scale_list(factor, items)
  local scale_by = map(function(x) return x * factor end)
  return scale_by(items)
end
-- Exercise 2.17

--- One-element list holding the last element of T.
function last_pair(T)
  local n = #T
  return { T[n] }
end
-- prints {34}
printx(last_pair({23, 72, 149, 34}))
-- Exercise 2.18

--- New list with T's elements in reverse order (top level only).
function reverse(T)
  local reversed = {}
  local n = #T
  for offset = 0, n - 1 do
    reversed[#reversed + 1] = T[n - offset]
  end
  return reversed
end
-- prints {25,16,9,4,1}
printx(reverse({1, 4, 9, 16, 25}))
-- Exercise 2.19

--- True when no coin denominations are left.
function no_more(T)
  return not (#T > 0)
end

--- All denominations except the first (copied via slice).
function except_first_denomination(coin_values)
  return slice(coin_values, 2)
end

--- The first denomination in the list.
function first_denomination(coin_values)
  return coin_values[1]
end

--- Number of ways to make `amount` from the coin list: ways without the
-- first coin plus ways that use at least one of it.
function cc(amount, coin_values)
  if amount == 0 then return 1 end
  if amount < 0 or no_more(coin_values) then return 0 end
  local without_first = cc(amount, except_first_denomination(coin_values))
  local using_first = cc(amount - first_denomination(coin_values), coin_values)
  return without_first + using_first
end
-- Coin lists for the counting-change demo. first_denomination reads
-- element 1, so list order decides which coin is tried first.
us_coins = {50, 25, 10, 5, 1}
uk_coins = {100, 50, 20, 10, 5, 2, 1, 0.5}
-- prints 292
printx (cc(100, us_coins))
-- works but takes a long time based on inefficiency above (slice)
-- printx (cc(100, uk_coins))
-- Exercise 2.20

--- Elements of sequence T (in index order) for which predicate is truthy.
-- Walks 1..#T because `pairs` does not guarantee numeric order.
function filter(predicate, T)
  local rt = {}
  for i = 1, #T do
    local value = T[i]
    if predicate(value) then
      rt[#rt + 1] = value
    end
  end
  return rt
end

--- Parity tests using `%`. math.mod is gone in Lua 5.2+, and `%` also
-- gives 1 for negative odd numbers (-3 % 2 == 1).
function is_odd(n) return n % 2 == 1 end
function is_even(n) return not is_odd(n) end

--- SICP's same-parity: the first element followed by every later element
-- with the same parity. The first element trivially matches its own
-- parity, so filtering the whole list keeps it (the original filtered
-- only the tail and dropped it). Empty input gives an empty list.
function same_parity(T)
  if #T == 0 then return {} end
  local predicate = is_odd(T[1]) and is_odd or is_even
  return filter(predicate, T)
end
-- same_parity on a list starting with an odd, then an even, element
printx (same_parity({1, 2, 3, 4, 5, 6, 7}))
printx (same_parity({2, 3, 4, 5, 6, 7}))
-- Exercise 2.21

--- Square every value of T, keeping keys.
-- The result table is now `local`. The original assigned to a global
-- `rt`, leaking state into the global environment on every call.
function square_list(T)
  local rt = {}
  for key, value in pairs(T) do
    rt[key] = value * value
  end
  return rt
end
-- prints {1,4,9,16}
printx (square_list({1, 2, 3, 4}))
--- square_list expressed with the curried map.
function square_list(T)
  local squarer = map(function(X) return X * X end)
  return squarer(T)
end
-- prints {1,4,9,16}
printx (square_list({1, 2, 3, 4}))
-- Exercise 2.23

--- Call f on each element of sequence T, in order 1..#T, for its side
-- effects. Iterates by index because `pairs` order is unspecified, and
-- for-each only makes sense when the order of effects is predictable.
-- Returns nothing.
function for_each(f, T)
  for i = 1, #T do
    f(T[i])
  end
end
-- 2.2.2 Hierarchical Data and the Closure Property - Hierarchical Structures

--- Number of non-table values anywhere inside the tree T.
function count_leaves(T)
  local total = 0
  for _, node in pairs(T) do
    -- count_leaves always returns a number (truthy), so and/or is safe
    total = total + (type(node) == 'table' and count_leaves(node) or 1)
  end
  return total
end
x = {{1, 2}, {3, 4}}
-- prints 2 (top-level elements)
printx (#x)
-- prints 4
printx (count_leaves(x))
printx ({x, x})
-- prints 2
printx (#{x, x})
-- prints 8
printx (count_leaves({x, x}))
-- Mapping over trees

--- Copy of tree T with every leaf multiplied by factor (keys preserved).
function scale_tree(factor, T)
  local scaled = {}
  for k, node in pairs(T) do
    if type(node) ~= 'table' then
      scaled[k] = factor * node
    else
      scaled[k] = scale_tree(factor, node)
    end
  end
  return scaled
end
-- every leaf multiplied by 10, structure preserved
printx (scale_tree(10, {1, {2, {3, 4}, 5}, {6, 7}}))
--- scale_tree expressed as a curried map over sub-trees.
function scale_tree(factor, T)
  local function scale_node(sub_tree)
    if type(sub_tree) ~= 'table' then
      return sub_tree * factor
    end
    return scale_tree(factor, sub_tree)
  end
  return map(scale_node)(T)
end
-- same result as the explicit-recursion scale_tree demo
printx (scale_tree(10, {1, {2, {3, 4}, 5}, {6, 7}}))
-- Exercise 2.24
-- (list 1 (list 2 (list 3 4))) as nested tables
printx ({1, {2, {3, 4}}})
-- Exercise 2.25
-- the three structures from which the exercise extracts 7
printx ({1, 3, {5, 7}, 9})
printx ({{7}})
printx ({1, {2, {3, {4, {5, {6, 7}}}}}})
-- Exercise 2.26
x = {1, 2, 3}
y = {4, 5, 6}
-- flat concatenation: {1,2,3,4,5,6}
printx (append(x, y))
-- two-element list of lists: {{1,2,3},{4,5,6}}
printx ({x, y})
-- Exercise 2.27

--- Reverse T and, recursively, every nested table inside it.
function deep_reverse(T)
  local out = {}
  local n = #T
  for offset = 0, n - 1 do
    local item = T[n - offset]
    if type(item) == 'table' then
      item = deep_reverse(item)
    end
    out[#out + 1] = item
  end
  return out
end
x = {{1, 2}, {3, 4}}
printx (x)
-- top level only: {{3,4},{1,2}}
printx (reverse(x))
-- all levels: {{4,3},{2,1}}
printx (deep_reverse(x))
-- Exercise 2.28

--- Leaves of tree T in left-to-right order, as a flat list.
-- Walks each level by index because `pairs` does not guarantee numeric
-- order. Collects into one result table instead of copying every
-- sub-result, so each leaf is appended exactly once.
function fringe(T)
  local leaves = {}
  local function collect(node)
    for i = 1, #node do
      local value = node[i]
      if type(value) == 'table' then
        collect(value)
      else
        leaves[#leaves + 1] = value
      end
    end
  end
  collect(T)
  return leaves
end
x = {{1, 2}, {3, 4}}
-- leaves of x, left to right
printx (fringe(x))
-- leaves of {x, x}: x's leaves twice
printx (fringe({x, x}))
-- Exercise 2.29
-- Record-based representation (tables with named fields)
-- a.
function make_mobile(left, right) return {left=left, right=right} end
function make_branch(length, struc) return {length=length, struc=struc} end

-- Helpers for b. and c.

--- Weight hanging from a branch. Its structure is either a number (a
-- simple weight) or another mobile, whose total weight is used.
-- The original tested `#branch == 0`, which is always true for a record
-- table, so every branch weighed 0. Its recursive case also read the
-- misspelt field `struct`.
function branch_weight(branch)
  local structure = branch.struc
  if type(structure) == 'table' then
    return total_weight(structure)
  end
  return structure
end

--- Sum of this branch's length and the lengths of every branch in the
-- mobile hanging from it, if any. (The original always returned 0 for
-- the same `#branch` reason, and would have recursed into a mobile as if
-- it were a branch.)
function total_branch_length(branch)
  local structure = branch.struc
  if type(structure) == 'table' then
    return branch.length
      + total_branch_length(structure.left)
      + total_branch_length(structure.right)
  end
  return branch.length
end

-- b.
function total_weight(mobile)
  return branch_weight(mobile.left) + branch_weight(mobile.right)
end

-- c.
--- SICP's balance rule: the torques (branch length * hanging weight) of
-- the two top-level branches are equal, and every sub-mobile hanging
-- from either branch is itself balanced.
function is_mobile_balanced(mobile)
  local function torque(branch)
    return branch.length * branch_weight(branch)
  end
  local function hangs_balanced(branch)
    local structure = branch.struc
    return type(structure) ~= 'table' or is_mobile_balanced(structure)
  end
  return torque(mobile.left) == torque(mobile.right)
    and hangs_balanced(mobile.left)
    and hangs_balanced(mobile.right)
end
-- Exercise 2.30

--- Copy of tree T with every leaf squared (keys preserved).
function square_tree(T)
  local squared = {}
  for k, node in pairs(T) do
    if type(node) ~= 'table' then
      squared[k] = node * node
    else
      squared[k] = square_tree(node)
    end
  end
  return squared
end
-- every leaf squared, structure preserved
printx (square_tree({1, {2, {3, 4}, 5}, {6, 7}}))
-- Exercise 2.31

--- Copy of tree T with proc applied to every leaf (keys preserved).
function tree_map(T, proc)
  local out = {}
  for k, node in pairs(T) do
    if type(node) ~= 'table' then
      out[k] = proc(node)
    else
      out[k] = tree_map(node, proc)
    end
  end
  return out
end

--- square_tree expressed with tree_map.
function square_tree(T)
  local function sq(x) return x * x end
  return tree_map(T, sq)
end
-- same result as the explicit square_tree demo
printx (square_tree({1, {2, {3, 4}, 5}, {6, 7}}))
-- Exercise 2.32

--- All subsets of sequence T: the subsets of its tail, followed by those
-- same subsets with T's first element prepended.
function subsets(T)
  if #T == 0 then
    return {{}}
  end
  local head = T[1]
  local rest = subsets(slice(T, 2))
  local with_head = map(function(s) return append({head}, s) end)(rest)
  return append(rest, with_head)
end
-- all 8 subsets of {1,2,3}
printx (subsets({1, 2, 3}))
-- 2.2.3 Hierarchical Data and the Closure Property - Sequences as Conventional Interfaces

--- Parity tests using `%`. math.mod does not exist in Lua 5.2+, and `%`
-- also classifies negative odd numbers correctly (-3 % 2 == 1).
function is_odd(n) return n % 2 == 1 end
function is_even(n) return not is_odd(n) end
function square(x) return x * x end

--- Sum of the squares of the odd leaves of tree T.
-- (Summation is order-independent, so `pairs` is fine here.)
function sum_odd_squares(T)
  local sum = 0
  for _, value in pairs(T) do
    if type(value) == 'table' then
      sum = sum + sum_odd_squares(value)
    elseif is_odd(value) then
      sum = sum + square(value)
    end
  end
  return sum
end
--- Even Fibonacci numbers fib(k) for k = 0..n, in order.
-- Starts at k = 0 as SICP does and as the accumulate-based even_fibs in
-- this file does (enumerate_interval(0, n)). The original started at 1
-- and so omitted fib(0) = 0.
function even_fibs(n)
  local rt = {}
  for k = 0, n do
    local f = fib(k)
    if is_even(f) then
      rt[#rt + 1] = f
    end
  end
  return rt
end
-- even Fibonacci numbers up to fib(10)
printx (even_fibs(10))
-- Sequence operations
-- prints {1,4,9,16,25}
printx (map (square) ({1,2,3,4,5}))
-- non-curried version of filter

--- Elements of sequence T (in index order) for which predicate is truthy.
-- Walks 1..#T instead of using `pairs`, whose traversal order is
-- unspecified even for integer keys.
function filter(predicate, T)
  local rt = {}
  for i = 1, #T do
    local value = T[i]
    if predicate(value) then
      rt[#rt + 1] = value
    end
  end
  return rt
end
-- odd elements of {1,2,3,4,5}
printx (filter(is_odd, {1,2,3,4,5}))
-- curried version of filter

--- filter(predicate)(T): elements of sequence T, in index order, for
-- which predicate is truthy. Iterates 1..#T because `pairs` order is
-- unspecified.
function filter(predicate)
  local function filter_lambda(T)
    local rt = {}
    for i = 1, #T do
      local value = T[i]
      if predicate(value) then
        rt[#rt + 1] = value
      end
    end
    return rt
  end
  return filter_lambda
end
-- odd elements, using the curried filter
printx (filter (is_odd) ({1,2,3,4,5}))
-- non-curried version of accumulate (aka foldl)

--- Left-to-right fold: rt = oper(T[1], initial), then oper(T[2], rt), ...
-- Note the argument order: element first, accumulator second.
-- Iterates 1..#T because `pairs` does not guarantee numeric order, and a
-- fold's result depends on order for non-commutative operators.
function accumulate(oper, initial, T)
  local rt = initial
  for i = 1, #T do
    rt = oper(T[i], rt)
  end
  return rt
end
-- sum (15), product (120), and rebuilding the list in order
printx (accumulate(function(x,y) return x+y end, 0, {1,2,3,4,5}))
printx (accumulate(function(x,y) return x*y end, 1, {1,2,3,4,5}))
printx (accumulate(function(x,y) table.insert(y, x); return y end, {}, {1,2,3,4,5}))
-- curried version of accumulate (aka foldl)

--- accumulate(oper)(initial)(T): left-to-right fold with
-- rt = oper(T[i], rt) for i = 1..#T (element first, accumulator second).
-- Iterates by index because `pairs` does not guarantee numeric order.
function accumulate(oper)
  local function initial_lambda(initial)
    local function sequence_lambda(T)
      local rt = initial
      for i = 1, #T do
        rt = oper(T[i], rt)
      end
      return rt
    end
    return sequence_lambda
  end
  return initial_lambda
end
-- curried forms of the sum, product and list-rebuilding demos
printx (accumulate (function(x,y) return x+y end) (0) ({1,2,3,4,5}))
printx (accumulate (function(x,y) return x*y end) (1.0) ({1,2,3,4,5}))
printx (accumulate (function(x,y) table.insert(y, x); return y end) ({}) ({1,2,3,4,5}))
--- List of the integers low, low+1, ..., high (empty when low > high).
function enumerate_interval(low, high)
  local out = {}
  local value = low
  while value <= high do
    out[#out + 1] = value
    value = value + 1
  end
  return out
end
-- prints {2,3,4,5,6,7}
printx (enumerate_interval(2,7))
--- Leaves of tree T, left to right, as a flat list.
-- Walks each level by index because `pairs` does not guarantee numeric
-- order. Appends into a single result table instead of re-copying every
-- sub-result.
function enumerate_tree(T)
  local leaves = {}
  local function walk(node)
    for i = 1, #node do
      local value = node[i]
      if type(value) == 'table' then
        walk(value)
      else
        leaves[#leaves + 1] = value
      end
    end
  end
  walk(T)
  return leaves
end
-- leaves of the tree, left to right
printx (enumerate_tree({1, {2, {3, 4}, 5}}))
--- Sum of squares of the odd leaves, as an enumerate/filter/map/accumulate
-- pipeline.
function sum_odd_squares(tree)
  local add = function(x, y) return x + y end
  local odd_leaves = filter(is_odd)(enumerate_tree(tree))
  return accumulate(add)(0)(map(square)(odd_leaves))
end

--- Even values among fib(0..n), collected into a list via accumulate.
function even_fibs(n)
  local push = function(x, y) table.insert(y, x); return y end
  local evens = filter(is_even)(map(fib)(enumerate_interval(0, n)))
  return accumulate(push)({})(evens)
end

--- Squares of fib(0..n), collected into a list via accumulate.
function list_fib_squares(n)
  local push = function(x, y) table.insert(y, x); return y end
  local squares_of_fibs = map(square)(map(fib)(enumerate_interval(0, n)))
  return accumulate(push)({})(squares_of_fibs)
end
-- squares of fib(0) .. fib(10)
printx (list_fib_squares(10))
--- Product of the squares of the odd elements of `sequence`.
function product_of_squares_of_odd_elements(sequence)
  local multiply = function(x, y) return x * y end
  local odd_squares = map(square)(filter(is_odd)(sequence))
  return accumulate(multiply)(1)(odd_squares)
end
-- 1^2 * 3^2 * 5^2 = 225
printx (product_of_squares_of_odd_elements({1,2,3,4,5}))
--- Employee record with fields empname, jobtitle and salary.
function Employee(init_empname, init_jobtitle, init_salary)
  return {
    empname = init_empname,
    jobtitle = init_jobtitle,
    salary = init_salary,
  }
end

--- True when the employee's job title is exactly "Programmer".
function isProgrammer(emp)
  local title = emp.jobtitle
  return title == "Programmer"
end

--- The employee's salary field.
function getSalary(emp)
  local pay = emp.salary
  return pay
end

--- Highest salary among programmers (0 when there are none).
function salary_of_highest_paid_programmer(records)
  local fold_max = accumulate(math.max)(0)
  return fold_max(map(getSalary)(filter(isProgrammer)(records)))
end
-- Two programmer records; the highest salary printed is 180.
recs = {Employee("Fred", "Programmer", 180),
Employee("Hank", "Programmer", 150)}
printx (salary_of_highest_paid_programmer(recs))
-- Nested mappings
n = 5 -- book doesn't define n
-- Concatenate the per-i lists of pairs in order. accumulate folds from the
-- first element onward and passes (element, accumulator), so the combiner
-- must append each new group AFTER what has been collected. Passing
-- `append` directly put later groups in front, reversing the book's
-- ((2 1) (3 1) (3 2) ...) order.
printx (accumulate (function(group, collected) return append(collected, group) end) ({}) (
  map (
    function(i)
      return map (function(j) return {i,j} end) (enumerate_interval(1, i-1))
    end) (
    enumerate_interval(1, n))))
--- flatmap(proc)(seq): apply proc (which returns a list) to each element
-- of seq and concatenate the resulting lists in order.
-- The original folded with `accumulate(append)`. That fold runs from the
-- first element and calls append(element, accumulator), so later results
-- were placed in front and the output groups came out reversed.
-- This version walks both levels by index and appends in order.
function flatmap(proc)
  return function(seq)
    local rt = {}
    for i = 1, #seq do
      local mapped = proc(seq[i])
      for j = 1, #mapped do
        rt[#rt + 1] = mapped[j]
      end
    end
    return rt
  end
end
--- True when n has no divisor in [2, c], checked from c downward.
-- Uses `%`: math.mod no longer exists in Lua 5.2+.
function has_no_divisors(n, c)
  if c == 1 then
    return true
  elseif n % c == 0 then
    return false
  else
    return has_no_divisors(n, c - 1)
  end
end

--- Trial-division primality test. Numbers below 2 are not prime.
-- Guarding them avoids calling has_no_divisors with c < 1, which would
-- evaluate `n % 0` (an error for integers in Lua 5.3+).
function is_prime(n)
  if n < 2 then
    return false
  end
  return has_no_divisors(n, n - 1)
end
--- True when the pair's x + y is prime.
function prime_sum(pair)
  local total = pair.x + pair.y
  return is_prime(total)
end

--- Pair extended with its sum: {x, y, sum}.
function make_pair_sum(pair)
  local a, b = pair.x, pair.y
  return { x = a, y = b, sum = a + b }
end

--- All pairs 1 <= y < x <= n whose sum is prime, each with its sum.
function prime_sum_pairs(n)
  local function pairs_below(i)
    return map(function(j) return {x = i, y = j} end)(enumerate_interval(1, i - 1))
  end
  local candidates = flatmap(pairs_below)(enumerate_interval(1, n))
  return map(make_pair_sum)(filter(prime_sum)(candidates))
end
-- all prime-sum pairs with both elements in 1..15
printx (prime_sum_pairs(15))
--- Sequence with every element equal to `item` removed.
function remove(item, sequence)
  local function keep(x) return x ~= item end
  return filter(keep)(sequence)
end

--- All orderings of T. For each x in T, every permutation of the other
-- elements gets x appended at the end.
function permutations(T)
  if #T == 0 then
    return {{}}
  end
  local function extend_with(x)
    local function add_x(a) return append(a, {x}) end
    return map(add_x)(permutations(remove(x, T)))
  end
  return flatmap(extend_with)(T)
end
-- the 6 orderings of {1,2,3}
printx (permutations({1,2,3}))
-- Exercise 2.34
-- exercise left to reader to define appropriate functions
-- function horner_eval(x, coefficient_sequence)
-- return accumulate (function(this_coeff, higher_terms) return ??FILL_THIS_IN?? end) (0) (coefficient_sequence)
-- end
-- horner_eval(2, {1,3,0,5,0,1})
-- Exercise 2.36
-- exercise left to reader to define appropriate functions
-- function accumulate_n(oper)
-- local function initial_lambda(initial)
-- local function sequence_lambda(sequence)
-- if (sequence == nil) then
-- return initial
-- else
-- return {accumulate (oper) (initial) (??FILL_THIS_IN??),
-- accumulate_n (oper) (initial) (??FILL_THIS_IN??)}
-- end
-- end
-- return sequence_lambda
-- end
-- return initial_lambda
-- end
-- accumulate_n (function(x,y) return x + y end) (0) (s)
-- Exercise 2.38

--- fold_right(oper)(initial)(T) = oper(T[1], oper(T[2], ... oper(T[n], initial))).
-- Previously this was an alias of `accumulate`, which combines elements
-- starting from the FIRST one (oper(T[n], ... oper(T[1], initial))).
-- That is a left fold with swapped arguments, not SICP's fold-right.
function fold_right(oper)
  return function(initial)
    return function(T)
      local rt = initial
      for i = #T, 1, -1 do
        rt = oper(T[i], rt)
      end
      return rt
    end
  end
end

--- fold_left(oper)(initial)(T) = oper(...oper(oper(initial, T[1]), T[2])..., T[n]).
-- The original walked from the last element to the first, combining the
-- elements in right-to-left order.
function fold_left(oper)
  local function initial_lambda(initial)
    local function sequence_lambda(T)
      local rt = initial
      for i = 1, #T do
        rt = oper(rt, T[i])
      end
      return rt
    end
    return sequence_lambda
  end
  return initial_lambda
end
-- Compare the two folds on a non-associative operator (division) and on
-- list-building combiners.
printx (fold_right (function(x,y) return x/y end) (1.0) ({1,2,3}))
printx (fold_left (function(x,y) return x/y end) (1.0) ({1,2,3}))
printx (fold_right (function(x,y) table.insert(y, x); return y end) ({}) ({1,2,3}))
printx (fold_left (function(x,y) table.insert(x, y); return x end) ({}) ({1,2,3}))
-- Exercise 2.42
-- exercise left to reader to define appropriate functions
-- function queens(board_size)
-- local function queen_cols(k)
-- if (k == 0) then
-- return {empty_board}
-- else
-- return (
-- filter (
-- function(positions) return isSafe(k, positions) end) (
-- flatmap (
-- function(rest_of_queens)
-- return
-- map (
-- function(new_row) return adjoin_position(new_row, k, rest_of_queens) end) (
-- enumerate_interval(1, board_size))
-- end) (
-- queen_cols(k-1))))
-- end
-- end
-- return queen_cols(board_size)
-- end
-- Exercise 2.43
-- exercise left to reader to define appropriate functions
-- function queens(board_size)
-- local function queen_cols(k)
-- if (k == 0) then
-- return {empty_board}
-- else
-- return (
-- filter (
-- function(positions) return isSafe(k, positions) end) (
-- flatmap (
-- function(new_row)
-- return
-- map (
-- function(rest_of_queens) return adjoin_position(new_row, k, rest_of_queens) end) (
-- queen_cols(k-1))
-- end) (
-- enumerate_interval(1, board_size))))
-- end
-- end
-- return queen_cols(board_size)
-- end
-- 2.2.4 Hierarchical Data and the Closure Property - Example: a picture language
-- these two routines are to be written

--- Placeholder drawing primitive: currently does nothing.
draw_line = function(x, y)
end

--- Placeholder painter: currently just returns the frame it is given.
wave = function(xframe)
  return xframe
end
--- 2-D vector record with fields x and y.
function Vect(init_x, init_y)
  return { x = init_x, y = init_y }
end

function make_vect(x, y) return Vect(x, y) end
function xcor_vect(v) return v.x end
function ycor_vect(v) return v.y end

--- Component-wise sum.
function add_vect(v1, v2)
  local x = xcor_vect(v1) + xcor_vect(v2)
  local y = ycor_vect(v1) + ycor_vect(v2)
  return make_vect(x, y)
end

--- Component-wise difference v1 - v2.
function sub_vect(v1, v2)
  local x = xcor_vect(v1) - xcor_vect(v2)
  local y = ycor_vect(v1) - ycor_vect(v2)
  return make_vect(x, y)
end

--- Scalar multiple s * v.
function scale_vect(s, v)
  local x = s * xcor_vect(v)
  local y = s * ycor_vect(v)
  return make_vect(x, y)
end
--- Frame record: an origin vector plus two edge vectors.
function Frame(init_orig, init_edge1, init_edge2)
  return { orig = init_orig, edge1 = init_edge1, edge2 = init_edge2 }
end

function make_frame(origin, edge1, edge2)
  local frame = Frame(origin, edge1, edge2)
  return frame
end

function origin_frame(f) return f.orig end
function edge1_frame(f) return f.edge1 end
function edge2_frame(f) return f.edge2 end
-- The unit-square frame: origin (0,0), edges along the x and y axes.
a_frame = make_frame(make_vect(0, 0), make_vect(1, 0), make_vect(0, 1))
--- Line segment record: field x holds the start vector and field y the
-- end vector (names kept from the original representation).
function Segment(init_x, init_y)
  return { x = init_x, y = init_y }
end

function start_segment(seg)
  return seg.x
end

function end_segment(seg)
  return seg.y
end
-- Frames

--- Map vector v, given in unit-square coordinates, into frame `xframe`:
-- origin + x(v) * edge1 + y(v) * edge2.
-- SICP's frame-coord-map is curried (it returns a procedure of v). When
-- called with only a frame, this now returns that procedure, so both
-- `frame_coord_map(xframe, v)` and `frame_coord_map(xframe)(v)` work.
-- The two-argument form behaves exactly as before.
function frame_coord_map(xframe, v)
  if v == nil then
    return function(vec)
      return frame_coord_map(xframe, vec)
    end
  end
  return add_vect(
    origin_frame(xframe),
    add_vect(scale_vect(xcor_vect(v), edge1_frame(xframe)),
             scale_vect(ycor_vect(v), edge2_frame(xframe))))
end
-- Mapping (0,0) into a_frame gives its origin; the return values of both
-- calls are discarded (nothing is printed).
frame_coord_map(a_frame, make_vect(0, 0))
origin_frame(a_frame)
-- Painters

--- foreach(f)(T): call f on each element of sequence T, in order 1..#T.
-- Iterates by index because `pairs` order is unspecified and drawing
-- order can matter. Returns nothing.
function foreach(f)
  local function foreach_lambda(T)
    for i = 1, #T do
      f(T[i])
    end
  end
  return foreach_lambda
end
--- Draw every segment in segment_list inside frame `xframe`.
-- Each segment's start and end points are mapped through the frame
-- before drawing. The original passed the selector function
-- `start_segment` and the segment as two separate arguments (through a
-- curried call) instead of the segment's actual start/end point.
function segments_painter(segment_list, xframe)
  foreach (
    function(segment)
      draw_line(
        frame_coord_map(xframe, start_segment(segment)),
        frame_coord_map(xframe, end_segment(segment)))
    end) (
    segment_list)
end
--- Painter that draws `painter` into the sub-frame whose origin and two
-- corners are given in unit-square coordinates of the target frame.
-- The coordinate map is a closure over the two-argument
-- frame_coord_map(xframe, v). The original called frame_coord_map with
-- only the frame, which the two-argument definition does not support, so
-- this now works whether or not frame_coord_map allows partial
-- application.
function transform_painter(painter, origin, corner1, corner2)
  local function transform_painter_lambda(xframe)
    local function m(v)
      return frame_coord_map(xframe, v)
    end
    local new_origin = m(origin)
    return painter(
      make_frame(
        new_origin,
        sub_vect(m(corner1), new_origin),
        sub_vect(m(corner2), new_origin)))
  end
  return transform_painter_lambda
end
--- Build a painter transformer that maps the unit square's origin and two
-- corners to the given points, using transform_painter.
local function unit_square_transform(ox, oy, c1x, c1y, c2x, c2y)
  return function(painter)
    return transform_painter(
      painter,
      make_vect(ox, oy),
      make_vect(c1x, c1y),
      make_vect(c2x, c2y))
  end
end

--- Mirror top-to-bottom.
flip_vert = unit_square_transform(0, 1, 1, 1, 0, 0)

--- Mirror left-to-right.
flip_horiz = unit_square_transform(1, 0, 0, 0, 1, 1)

--- Draw into the upper-right quarter of the frame.
shrink_to_upper_right = unit_square_transform(0.5, 0.5, 1, 0.5, 0.5, 1)

--- Rotate 90 degrees counter-clockwise.
rotate90 = unit_square_transform(1, 0, 1, 1, 0, 0)

--- Rotate 180 degrees.
rotate180 = unit_square_transform(1, 1, 0, 1, 1, 0)

--- Squash the image toward the diagonal.
squash_inwards = unit_square_transform(0, 0, 0.65, 0.35, 0.35, 0.65)
--- Painter that draws painter1 in the left half of the frame and
-- painter2 in the right half.
function beside(painter1, painter2)
  return function(xframe)
    local split_point = make_vect(0.5, 0)
    local left = transform_painter(painter1, make_vect(0, 0), split_point, make_vect(0, 1))
    local right = transform_painter(painter2, split_point, make_vect(1, 0), make_vect(0.5, 1))
    left(xframe)
    right(xframe)
  end
end

--- Painter that draws painter1 in the bottom half of the frame and
-- painter2 in the top half.
function below(painter1, painter2)
  return function(xframe)
    local split_point = make_vect(0, 0.5)
    local bottom = transform_painter(painter1, make_vect(0, 0), make_vect(1, 0), split_point)
    local top = transform_painter(painter2, split_point, make_vect(1, 0.5), make_vect(0, 1))
    bottom(xframe)
    top(xframe)
  end
end
--- painter below two side-by-side copies of up_split(painter, n - 1);
-- n == 0 returns the painter itself.
function up_split(painter, n)
  if n == 0 then
    return painter
  end
  local smaller = up_split(painter, n - 1)
  local pair = beside(smaller, smaller)
  return below(painter, pair)
end
-- wave beside its vertical mirror, then that pair stacked below wave
wave2 = beside(wave, flip_vert(wave))
wave4 = below(wave2, wave)
--- Two copies of (painter beside its vertical mirror), one stacked on
-- the other. Both halves are the same painter object.
function flipped_pairs(painter)
  local mirrored = flip_vert(painter)
  local pair = beside(painter, mirrored)
  return below(pair, pair)
end
-- the same four-wave picture, built with flipped_pairs
wave4 = flipped_pairs(wave)
function right_split(painter, n)
if (n == 0) then
return painter
else
local smaller = right_split(painter, n-1)
return beside(painter, below(smaller, smaller))
end
end
function corner_split(painter, n)
if (n == 0) then
return painter
else
local up = up_split(painter, n-1)
local right = right_split(painter, n-1)
local top_left = beside(up, up)
local bottom_right = below(right, right)
local corner = corner_split(painter, n-1)
return beside(below(painter, top_left), below(bottom_right, corner))
end
end
-- square_limit: mirror a corner_split horizontally into a half, then
-- mirror that half vertically into the full square.
square_limit = function(painter, n)
  local quarter = corner_split(painter, n)
  local mirrored_quarter = flip_horiz(quarter)
  local half = beside(mirrored_quarter, quarter)
  local mirrored_half = flip_vert(half)
  return below(mirrored_half, half)
end
-- Higher_order operations
-- square_of_four: build a combinator that arranges four transformed copies
-- of a painter in a 2x2 grid.
--   tleft, tright: painter -> painter transforms for the top-left and top-right cells
--   bleft, bright: painter -> painter transforms for the bottom-left and bottom-right cells
-- Returns a function that takes a painter and returns the combined painter.
function square_of_four(tleft, tright, bleft, bright)
  local function square_of_four_lambda(painter)
    local top = beside(tleft(painter), tright(painter))
    -- The bottom-left cell uses bleft; previously bright was applied twice
    -- and bleft was ignored.
    local bottom = beside(bleft(painter), bright(painter))
    -- below's first argument is drawn in the lower half of the frame.
    return below(bottom, top)
  end
  return square_of_four_lambda
end
-- flipped_pairs expressed with square_of_four: the left column is the
-- painter itself and the right column is its vertical flip.
flipped_pairs = function(painter)
  return square_of_four(identity, flip_vert, identity, flip_vert)(painter)
end
-- footnote: flipped_pairs written point-free. It is bound directly to the
-- combinator returned by square_of_four (identity for the left column,
-- flip_vert for the right column) instead of wrapping it in a function.
flipped_pairs = square_of_four(identity, flip_vert, identity, flip_vert)
-- square_limit expressed with square_of_four: the four quadrants hold
-- flip_horiz, identity, rotate180 and flip_vert of the corner_split.
square_limit = function(painter, n)
  local arrange = square_of_four(flip_horiz, identity, rotate180, flip_vert)
  local quarter = corner_split(painter, n)
  return arrange(quarter)
end
-- Exercise 2.45
-- Left as an exercise for the reader: define a generic `split` so that the following work:
-- right_split = split(beside, below)
-- up_split = split(below, beside)
-- Exercise 2.47
-- make_frame (first representation): a flat 3-element list
-- {origin, edge1, edge2}.
make_frame = function(origin, edge1, edge2)
  local frame = {}
  frame[1] = origin
  frame[2] = edge1
  frame[3] = edge2
  return frame
end
-- make_frame (second representation): origin paired with a nested list
-- of the two edges, i.e. {origin, {edge1, edge2}}.
make_frame = function(origin, edge1, edge2)
  local edges = {edge1, edge2}
  return {origin, edges}
end
-- 2.3.1 Symbolic Data - Quotation
-- To Be Done.
-- 2.3.2 Symbolic Data - Example: Symbolic Differentiation
-- True only when x and y are both numbers and compare equal.
function is_same_number(x, y)
  if type(x) ~= 'number' or type(y) ~= 'number' then
    return false
  end
  return x == y
end
-- Variables in the symbolic-differentiation representation are strings.
function is_variable(x)
  local kind = type(x)
  return kind == 'string'
end
-- True when x and y are both variables (strings) with the same name.
function is_same_variable(x, y)
  if not (is_variable(x) and is_variable(y)) then
    return false
  end
  return x == y
end
-- True when T is a sum node built by make_sum ({tag='sum', ...}).
-- Non-table inputs (numbers, nil) return false instead of raising an
-- "attempt to index" error. make_sum can return plain numbers, so callers
-- such as add_end may receive them.
function is_sum(T)
  return type(T) == 'table' and T.tag == 'sum'
end
-- True when T is a product node built by make_product ({tag='product', ...}).
-- Non-table inputs (numbers, nil) return false instead of raising an
-- "attempt to index" error.
function is_product(T)
  return type(T) == 'table' and T.tag == 'product'
end
-- make_sum (no simplification): fold two numbers into their sum,
-- otherwise build a {tag='sum', add_end=x, aug_end=y} node.
function make_sum(x, y)
  local both_numeric = type(x) == 'number' and type(y) == 'number'
  if not both_numeric then
    return {tag='sum', add_end=x, aug_end=y}
  end
  return x + y
end
-- make_product (no simplification): fold two numbers into their product,
-- otherwise build a {tag='product', multiplier=x, multiplicand=y} node.
function make_product(x, y)
  local both_numeric = type(x) == 'number' and type(y) == 'number'
  if not both_numeric then
    return {tag='product', multiplier=x, multiplicand=y}
  end
  return x * y
end
-- Selector: first operand of a sum node; raises for anything else.
function add_end(T)
  if not is_sum(T) then
    error('Invalid pattern match ' .. tostring(T))
  end
  return T.add_end
end
-- Selector: second operand of a sum node; raises for anything else.
function aug_end(T)
  local matched = is_sum(T)
  if matched then
    return T.aug_end
  end
  error('Invalid pattern match ' .. tostring(T))
end
-- Selector: first factor of a product node; raises for anything else.
function multiplier(T)
  if not is_product(T) then
    error('Invalid pattern match ' .. tostring(T))
  end
  return T.multiplier
end
-- Selector: second factor of a product node; raises for anything else.
function multiplicand(T)
  local matched = is_product(T)
  if matched then
    return T.multiplicand
  end
  error('Invalid pattern match ' .. tostring(T))
end
-- deriv: symbolic derivative of exp with respect to the variable var.
-- exp is a number, a variable (string), or a sum/product node.
-- Uses the sum rule and the product rule; any other shape raises.
function deriv(exp, var)
  if type(exp) == 'number' then
    return 0
  end
  if is_variable(exp) then
    if is_same_variable(exp, var) then
      return 1
    end
    return 0
  end
  if is_sum(exp) then
    local d_first = deriv(add_end(exp), var)
    local d_second = deriv(aug_end(exp), var)
    return make_sum(d_first, d_second)
  end
  if is_product(exp) then
    -- product rule: u * v' + u' * v
    local u_dv = make_product(multiplier(exp), deriv(multiplicand(exp), var))
    local du_v = make_product(deriv(multiplier(exp), var), multiplicand(exp))
    return make_sum(u_dv, du_v)
  end
  error('Invalid expression ' .. tostring(exp))
end
-- Demos for the non-simplifying constructors. Each comment gives the
-- mathematical derivative. The make_sum/make_product above only fold
-- purely numeric operands, so the last two calls print unsimplified
-- expression tables (e.g. x*0 + 1*y).
-- dx(x + 3) = 1
printx (deriv({tag='sum', add_end='x', aug_end=3}, 'x'))
-- dx(x*y) = y
printx(deriv({tag='product', multiplier='x', multiplicand='y'}, 'x'))
-- dx(x*y + x + 3) = y + 1
printx(deriv({tag='sum', add_end={tag='sum', add_end={tag='product', multiplier='x', multiplicand='y'}, aug_end='x'}, aug_end=3}, 'x'))
-- With simplification
-- make_sum with simplification: 0 + y -> y, x + 0 -> x, two numbers fold
-- to their sum; anything else becomes a {tag='sum', ...} node.
function make_sum(x, y)
  local x_is_num = type(x) == 'number'
  local y_is_num = type(y) == 'number'
  if x_is_num and x == 0 then return y end
  if y_is_num and y == 0 then return x end
  if x_is_num and y_is_num then return x + y end
  return {tag='sum', add_end=x, aug_end=y}
end
-- make_product with simplification: a 0 factor gives 0, a 1 factor gives
-- the other operand, two numbers fold to their product; anything else
-- becomes a {tag='product', ...} node.
function make_product(x, y)
  local x_is_num = type(x) == 'number'
  local y_is_num = type(y) == 'number'
  if (x_is_num and x == 0) or (y_is_num and y == 0) then return 0 end
  if x_is_num and x == 1 then return y end
  if y_is_num and y == 1 then return x end
  if x_is_num and y_is_num then return x * y end
  return {tag='product', multiplier=x, multiplicand=y}
end
-- deriv (same algorithm, now paired with the simplifying constructors):
-- symbolic derivative of exp with respect to var using the sum and
-- product rules. Numbers differentiate to 0, var to 1, other variables to 0.
function deriv(exp, var)
  if type(exp) == 'number' then
    return 0
  elseif is_variable(exp) then
    return is_same_variable(exp, var) and 1 or 0
  elseif is_sum(exp) then
    local left = deriv(add_end(exp), var)
    local right = deriv(aug_end(exp), var)
    return make_sum(left, right)
  elseif is_product(exp) then
    local first_term = make_product(multiplier(exp), deriv(multiplicand(exp), var))
    local second_term = make_product(deriv(multiplier(exp), var), multiplicand(exp))
    return make_sum(first_term, second_term)
  end
  error('Invalid expression ' .. tostring(exp))
end
-- The same demos with the simplifying constructors. Expected output:
-- 1, then y, then a sum table for y + 1 (field order may vary).
-- dx(x + 3) = 1
printx (deriv({tag='sum', add_end='x', aug_end=3}, 'x'))
-- dx(x*y) = y
printx(deriv({tag='product', multiplier='x', multiplicand='y'}, 'x'))
-- dx(x*y + x + 3) = y + 1
printx(deriv({tag='sum', add_end={tag='sum', add_end={tag='product', multiplier='x', multiplicand='y'}, aug_end='x'}, aug_end=3}, 'x'))
-- 2.3.3 Symbolic Data - Example: Representing Sets
-- unordered
-- Unordered set: linear scan of every value in T for one equal to x.
function is_element_of_set(x, T)
  local found = false
  for _, member in pairs(T) do
    if member == x then
      found = true
      break
    end
  end
  return found
end
-- Unordered set: return T itself when x is already present, otherwise a
-- new list with x prepended.
function adjoin_set(x, T)
  if not is_element_of_set(x, T) then
    return append({x}, T)
  end
  return T
end
-- Unordered set: collect the values of T1 that are also members of T2.
function intersection_set(T1, T2)
  local common = {}
  for _, item in pairs(T1) do
    if is_element_of_set(item, T2) then
      table.insert(common, item)
    end
  end
  return common
end
-- ordered
-- Ordered set: T is a sequence sorted in ascending order.
-- Scans from the smallest element and stops as soon as it passes the
-- position where x would be.
-- A numeric loop is required: pairs() does not guarantee ascending key
-- order, and the early `x < value` exit is only correct when elements are
-- visited in order.
function is_element_of_set(x, T)
  for i = 1, #T do
    local value = T[i]
    if x == value then
      return true
    elseif x < value then
      return false
    end
  end
  return false
end
-- Ordered set intersection: merge-style walk over two ascending sequences,
-- advancing whichever cursor points at the smaller element.
function intersection_set(T1, T2)
  local out = {}
  local a, b = 1, 1
  local n1, n2 = #T1, #T2
  while a <= n1 and b <= n2 do
    local u, v = T1[a], T2[b]
    if u == v then
      out[#out + 1] = u
      a, b = a + 1, b + 1
    elseif u < v then
      a = a + 1
    else
      b = b + 1
    end
  end
  return out
end
-- Binary-tree set: each node is {value=..., left=..., right=...}, with
-- smaller values in the left subtree. Descend toward x; nil means absent.
function is_element_of_set(x, node)
  if node == nil then
    return false
  end
  local here = node.value
  if x == here then
    return true
  end
  if x < here then
    return is_element_of_set(x, node.left)
  end
  return is_element_of_set(x, node.right)
end
-- Demo: 3 is the right child of root 2, so this prints true.
printx (is_element_of_set(3,
{value=2,
left ={value=1, left=nil, right=nil},
right={value=3, left=nil, right=nil}}))
-- Binary-tree set: return a tree containing x. Nodes on the path to x are
-- copied; untouched subtrees are shared with the input. If x is already
-- present, the original node is returned.
function adjoin_set(x, node)
  if node == nil then
    return {value=x, left=nil, right=nil}
  end
  local v = node.value
  if x == v then
    return node
  end
  if x < v then
    return {value=v, left=adjoin_set(x, node.left), right=node.right}
  end
  return {value=v, left=node.left, right=adjoin_set(x, node.right)}
end
-- Demo: insert 3 into the tree rooted at 4; 3 < 4 and 3 > 2, so it becomes
-- the right child of 2.
printx (
adjoin_set(
3,
{value=4,
left ={value=2, left=nil, right=nil},
right={value=6, left=nil, right=nil}}))
-- Exercise 2.63
-- Exercise 2.63, version 1: in-order traversal that appends
-- left-subtree list, node value, right-subtree list.
function tree_to_list(node)
  if node == nil then
    return {}
  end
  local left_items = tree_to_list(node.left)
  local right_items = tree_to_list(node.right)
  local middle_and_right = append({node.value}, right_items)
  return append(left_items, middle_and_right)
end
-- Demo: in-order traversal; expected elements 2, 4, 6.
printx (
tree_to_list(
{value=4,
left ={value=2, left=nil, right=nil},
right={value=6, left=nil, right=nil}}))
-- Exercise 2.63, version 2: accumulator-based traversal. The right subtree
-- is flattened onto the accumulator first, the node value is prepended,
-- then the left subtree is processed.
function tree_to_list(node)
  local function walk(subtree, acc)
    if subtree == nil then
      return acc
    end
    local rest = walk(subtree.right, acc)
    return walk(subtree.left, append({subtree.value}, rest))
  end
  return walk(node, {})
end
-- Demo: the accumulator version yields the same order, 2, 4, 6.
printx (
tree_to_list(
{value=4,
left ={value=2, left=nil, right=nil},
right={value=6, left=nil, right=nil}}))
-- Exercise 2.64
-- Exercise 2.64: build a balanced tree from the first n elements of the
-- ascending sequence elts.
-- Returns {foo=tree, bar=remaining_elements}: `foo` is the tree built so
-- far and `bar` is the list of elements not yet consumed.
function partial_tree(elts, n)
  if n == 0 then
    return {foo=nil, bar=elts}
  end
  local left_count = math.floor((n - 1) / 2)
  local right_count = n - (left_count + 1)
  local left = partial_tree(elts, left_count)
  local after_left = left.bar
  local entry = after_left[1]
  local right = partial_tree(slice(after_left, 2), right_count)
  return {
    foo = {value=entry, left=left.foo, right=right.foo},
    bar = right.bar,
  }
end
-- Convert an ascending sequence into a balanced binary tree using
-- partial_tree over all of its elements. Returns the tree (nil for an
-- empty list).
function list_to_tree(elements)
  -- `local` so the intermediate result does not leak into _G.
  local result = partial_tree(elements, #elements)
  return result.foo
end
-- Demo: balanced tree with 4 at the root and 2 and 6 as its children.
printx (list_to_tree({2, 4, 6}))
-- information retrieval
-- Linear search over records {key=..., value=...}; returns the value of
-- the first record found whose key equals given_key, or nil.
function lookup(given_key, T)
  local found = nil
  for _, record in pairs(T) do
    if given_key == record.key then
      found = record.value
      break
    end
  end
  return found
end
-- Demo: prints b, the value stored under key 2.
print (lookup(2, {{key=1, value='a'}, {key=2, value='b'}, {key=3, value='c'}}))
-- 2.3.4 Symbolic Data - Example: Huffman Encoding Trees
-- Huffman leaf: {tag='leaf', symbol=..., weight=...}.
make_leaf = function(symbol, weight)
  local leaf = {tag='leaf'}
  leaf.symbol = symbol
  leaf.weight = weight
  return leaf
end
-- True when node carries the 'leaf' tag.
is_leaf = function(node)
  return 'leaf' == node.tag
end
-- Selector: the symbol of a leaf; raises for a non-leaf node.
function symbol_leaf(node)
  if not is_leaf(node) then
    error('Invalid pattern match ' .. tostring(node))
  end
  return node.symbol
end
-- Selector: the weight of a leaf; raises for a non-leaf node.
function weight_leaf(node)
  if not is_leaf(node) then
    error('Invalid pattern match ' .. tostring(node))
  end
  return node.weight
end
-- Symbol list of a node: a one-element list for a leaf, otherwise the
-- subsymbols list stored on the tree node.
function symbols(node)
  if not is_leaf(node) then
    return node.subsymbols
  end
  return {node.symbol}
end
-- Weight of any Huffman node (leaf or tree); both store it in `weight`.
weight = function(node)
  local w = node.weight
  return w
end
-- Combine two Huffman nodes into a tree node whose symbol list is the
-- concatenation of the children's symbols and whose weight is their sum.
function make_code_tree(left, right)
  local merged_symbols = append(symbols(left), symbols(right))
  local total_weight = weight(left) + weight(right)
  return {
    tag = 'tree',
    subsymbols = merged_symbols,
    weight = total_weight,
    left = left,
    right = right,
  }
end
-- Selector: left child of a tree node; raises when given a leaf.
function left_node(node)
  if is_leaf(node) then
    error('Invalid pattern match ' .. tostring(node))
  end
  return node.left
end
-- Selector: right child of a tree node; raises when given a leaf.
function right_node(node)
  if is_leaf(node) then
    error('Invalid pattern match ' .. tostring(node))
  end
  return node.right
end
-- Follow one bit down the tree: 0 selects the left child, 1 the right.
-- Any other bit value raises.
function choose_node(n, node)
  if n == 0 then return left_node(node) end
  if n == 1 then return right_node(node) end
  error('Invalid pattern match ' .. tostring(n))
end
-- decoding
-- Decode a sequence of bits (0/1) using a Huffman tree.
-- Walks from the root one bit at a time. On reaching a leaf it emits that
-- symbol and restarts at the root. Trailing bits that do not reach a leaf
-- are ignored. Returns the list of decoded symbols.
-- A single loop is used instead of a recursive helper. The old
-- `function decode_1` leaked into the global namespace and copied the
-- remaining bits on every step (O(n^2)).
function decode(bits, tree)
  local decoded = {}
  local current = tree
  for i = 1, #bits do
    local next_node = choose_node(bits[i], current)
    if is_leaf(next_node) then
      decoded[#decoded + 1] = symbol_leaf(next_node)
      current = tree
    else
      current = next_node
    end
  end
  return decoded
end
-- sets
-- Huffman leaf sets: insert node x into T, a sequence of nodes kept in
-- ascending order of weight. Returns a new sequence; T is not modified.
-- The emptiness check now inspects T. It previously tested an undefined
-- global `node`, which is always nil, so every call returned {x}. The kept
-- head is wrapped in a list so append concatenates it as one element.
function adjoin_set(x, T)
  if T == nil or #T == 0 then
    return {x}
  end
  local head = T[1]
  if weight(x) < weight(head) then
    return append({x}, T)
  end
  local tail = slice(T, 2)
  return append({head}, adjoin_set(x, tail))
end
-- Build an ordered leaf set (ascending weight) from a list of entries, as
-- in SICP's make-leaf-set. Each entry is either a {symbol, weight} pair or
-- an existing leaf record ({tag='leaf', ...}). Returns {} for an empty list.
-- The previous version had no empty-list base case and called an undefined
-- `symbol_weight`, so it always raised.
function make_leaf_set(node)
  if #node == 0 then
    return {}
  end
  local head = node[1]
  local tail = slice(node, 2)
  local leaf
  if type(head) == 'table' and is_leaf(head) then
    leaf = make_leaf(symbol_leaf(head), weight_leaf(head))
  elseif type(head) == 'table' and head[1] ~= nil and head[2] ~= nil then
    leaf = make_leaf(head[1], head[2])
  else
    error('Invalid pattern match ' .. table_to_string(node))
  end
  return adjoin_set(leaf, make_leaf_set(tail))
end
-- Exercise 2.67
-- Exercise 2.67 data: A is encoded as 0, B as 10, D as 110, C as 111.
sample_tree = make_code_tree(
make_leaf('A', 4),
make_code_tree(
make_leaf('B', 2),
make_code_tree(
make_leaf('D', 1),
make_leaf('C', 1))))
sample_message = {0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0}
-- Expected decoded symbols: A D A B B C A
printx (decode(sample_message, sample_tree))
-- Exercise 2.68
-- Left as an exercise for the reader: define `encode_symbol` so that `encode` below works:
-- function encode(message, tree)
-- if #message == 0 then
-- return {}
-- else
-- local head = message[1]
-- local tail = slice(message, 2)
-- return append(encode_symbol(head, tree), encode(tail, tree))
-- end
-- end
-- 2.4.1 Multiple Representations for Abstract Data - Representations for Complex Numbers
-- Same as above
-- Square of a number.
square = function(x)
  return x * x
end
-- Rectangular
-- Rectangular form: real part is stored directly.
real_part_r = function(z)
  return z.real
end
-- Rectangular form: imaginary part is stored directly.
imag_part_r = function(z)
  return z.imag
end
-- Rectangular form: magnitude is sqrt(re^2 + im^2).
function magnitude_r(z)
  local re_sq = square(real_part_r(z))
  local im_sq = square(imag_part_r(z))
  return math.sqrt(re_sq + im_sq)
end
-- Rectangular form: angle is atan2(im, re), in radians.
-- math.atan2 is missing in Lua 5.3+ builds without compat; there the
-- two-argument math.atan(y, x) gives the same result.
function angle_r(z)
  local atan2 = math.atan2 or math.atan
  return atan2(imag_part_r(z), real_part_r(z))
end
-- Rectangular constructor from real and imaginary parts.
make_from_real_imag_r = function(x, y)
  local z = {}
  z.real = x
  z.imag = y
  return z
end
-- Rectangular constructor from magnitude r and angle a (radians).
function make_from_mag_ang_r(r, a)
  local z = {}
  z.real = r * math.cos(a)
  z.imag = r * math.sin(a)
  return z
end
-- polar
-- Polar form: magnitude is stored directly.
magnitude_p = function(z)
  return z.magnitude
end
-- Polar form: angle is stored directly.
angle_p = function(z)
  return z.angle
end
-- Polar form: real part is r * cos(angle).
function real_part_p(z)
  local r, theta = magnitude_p(z), angle_p(z)
  return r * math.cos(theta)
end
-- Polar form: imaginary part is r * sin(angle).
function imag_part_p(z)
  local r, theta = magnitude_p(z), angle_p(z)
  return r * math.sin(theta)
end
-- Polar constructor from real and imaginary parts: magnitude sqrt(x^2+y^2),
-- angle atan2(y, x) in radians.
-- math.atan2 is missing in Lua 5.3+ builds without compat; the
-- two-argument math.atan(y, x) is the equivalent there.
function make_from_real_imag_p(x, y)
  local atan2 = math.atan2 or math.atan
  return {magnitude=math.sqrt(square(x) + square(y)), angle=atan2(y, x)}
end
-- Polar constructor from magnitude and angle (stored as given).
function make_from_mag_ang_p(r, a)
  local z = {}
  z.magnitude = r
  z.angle = a
  return z
end
-- using the abstract type: bind the generic names to the rectangular
-- representation. add_complex and friends below use only these names.
magnitude = magnitude_r
angle = angle_r
real_part = real_part_r
imag_part = imag_part_r
make_from_real_imag = make_from_real_imag_r
make_from_mag_ang = make_from_mag_ang_r
-- Round trips: both lines should print a table equivalent to z (the second
-- goes through magnitude/angle, so it may differ by floating-point rounding).
z = {real=1, imag=2}
printx (make_from_real_imag(real_part(z), imag_part(z)))
printx (make_from_mag_ang(magnitude(z), angle(z)))
-- Complex addition: add real and imaginary parts component-wise.
function add_complex(z1, z2)
  local re = real_part(z1) + real_part(z2)
  local im = imag_part(z1) + imag_part(z2)
  return make_from_real_imag(re, im)
end
-- Complex subtraction: subtract real and imaginary parts component-wise.
function sub_complex(z1, z2)
  local re = real_part(z1) - real_part(z2)
  local im = imag_part(z1) - imag_part(z2)
  return make_from_real_imag(re, im)
end
-- Complex multiplication in polar terms: multiply magnitudes, add angles.
function mul_complex(z1, z2)
  local mag = magnitude(z1) * magnitude(z2)
  local ang = angle(z1) + angle(z2)
  return make_from_mag_ang(mag, ang)
end
-- Complex division in polar terms: divide magnitudes, subtract angles.
function div_complex(z1, z2)
  local mag = magnitude(z1) / magnitude(z2)
  local ang = angle(z1) - angle(z2)
  return make_from_mag_ang(mag, ang)
end
-- 2.4.2 Multiple Representations for Abstract Data - Tagged Data
-- Wrap a payload with a representation tag: {tag=..., contents=...}.
function attach_tag(type_tag, contents)
  local tagged = {}
  tagged.tag = type_tag
  tagged.contents = contents
  return tagged
end
-- Return the representation tag of a tagged datum. Only 'rectangular' and
-- 'polar' are recognised; any other tag raises.
function type_tag(a)
  local tag = a.tag
  if tag ~= 'rectangular' and tag ~= 'polar' then
    error('Invalid pattern match ' .. table_to_string(a))
  end
  return tag
end
-- Return the payload of a tagged datum with a recognised tag
-- ('rectangular' or 'polar'); any other tag raises.
function contents(a)
  local tag = a.tag
  if tag == 'rectangular' or tag == 'polar' then
    return a.contents
  end
  error('Invalid pattern match ' .. table_to_string(a))
end
-- True when the datum carries the 'rectangular' tag.
function is_rectangular(a)
  local tag = type_tag(a)
  return tag == 'rectangular'
end
-- True when the datum carries the 'polar' tag.
function is_polar(a)
  local tag = type_tag(a)
  return tag == 'polar'
end
-- Rectangular
-- Tagged rectangular constructor from real and imaginary parts.
function make_from_real_imag_rectangular(x, y)
  local payload = {real=x, imag=y}
  return attach_tag('rectangular', payload)
end
-- Tagged rectangular constructor from magnitude r and angle a (radians).
function make_from_mag_ang_rectangular(r, a)
  local payload = {}
  payload.real = r * math.cos(a)
  payload.imag = r * math.sin(a)
  return attach_tag('rectangular', payload)
end
-- Rectangular payload: real part is stored directly.
real_part_rectangular = function(z)
  return z.real
end
-- Rectangular payload: imaginary part is stored directly.
imag_part_rectangular = function(z)
  return z.imag
end
-- Rectangular payload: magnitude is sqrt(re^2 + im^2).
function magnitude_rectangular(z)
  local re_sq = square(real_part_rectangular(z))
  local im_sq = square(imag_part_rectangular(z))
  return math.sqrt(re_sq + im_sq)
end
-- Rectangular payload: angle is atan2(im, re), in radians.
-- The original body computed the value but never returned it, so every
-- call yielded nil (which broke angle_g / mul_complex_g for rectangular
-- numbers). math.atan2 is missing in Lua 5.3+ builds without compat; the
-- two-argument math.atan(y, x) is the equivalent there.
function angle_rectangular(z)
  local atan2 = math.atan2 or math.atan
  return atan2(imag_part_rectangular(z), real_part_rectangular(z))
end
-- Polar
-- Tagged polar constructor from real and imaginary parts.
-- The payload key is `magnitude`, the field magnitude_polar reads. It was
-- previously misspelled `magniture`, so magnitudes came back nil.
-- math.atan2 falls back to the two-argument math.atan on Lua 5.3+ builds
-- without compat.
function make_from_real_imag_polar(x, y)
  local atan2 = math.atan2 or math.atan
  return attach_tag('polar', {magnitude=math.sqrt(square(x) + square(y)), angle=atan2(y, x)})
end
-- Tagged polar constructor from magnitude r and angle a (radians).
-- The payload key is `magnitude`, matching magnitude_polar. It was
-- previously misspelled `magniture`, so every polar number built here
-- (including all results of make_from_mag_ang_g) had a nil magnitude.
function make_from_mag_ang_polar(r, a)
  return attach_tag('polar', {magnitude=r, angle=a})
end
-- Polar payload: magnitude is stored directly.
magnitude_polar = function(z)
  return z.magnitude
end
-- Polar payload: angle is stored directly.
angle_polar = function(z)
  return z.angle
end
-- Polar payload: real part is r * cos(angle).
function real_part_polar(z)
  local r, theta = magnitude_polar(z), angle_polar(z)
  return r * math.cos(theta)
end
-- Polar payload: imaginary part is r * sin(angle).
function imag_part_polar(z)
  local r, theta = magnitude_polar(z), angle_polar(z)
  return r * math.sin(theta)
end
-- Generic selectors
-- Generic selector: dispatch on the type tag to the matching representation.
function real_part_g(a)
  local tag = type_tag(a)
  if tag == 'rectangular' then
    return real_part_rectangular(contents(a))
  elseif tag == 'polar' then
    return real_part_polar(contents(a))
  end
  error('Invalid pattern match ' .. table_to_string(a))
end
-- Generic selector: dispatch on the type tag to the matching representation.
function imag_part_g(a)
  local tag = type_tag(a)
  if tag == 'rectangular' then
    return imag_part_rectangular(contents(a))
  elseif tag == 'polar' then
    return imag_part_polar(contents(a))
  end
  error('Invalid pattern match ' .. table_to_string(a))
end
-- Generic selector: dispatch on the type tag to the matching representation.
function magnitude_g(a)
  local tag = type_tag(a)
  if tag == 'rectangular' then
    return magnitude_rectangular(contents(a))
  elseif tag == 'polar' then
    return magnitude_polar(contents(a))
  end
  error('Invalid pattern match ' .. table_to_string(a))
end
-- Generic selector: dispatch on the type tag to the matching representation.
function angle_g(a)
  local tag = type_tag(a)
  if tag == 'rectangular' then
    return angle_rectangular(contents(a))
  elseif tag == 'polar' then
    return angle_polar(contents(a))
  end
  error('Invalid pattern match ' .. table_to_string(a))
end
-- Constructors for complex numbers
-- Generic constructor from real/imaginary parts: produces the rectangular
-- representation.
make_from_real_imag_g = function(x, y)
  return make_from_real_imag_rectangular(x, y)
end
-- Generic constructor from magnitude/angle: produces the polar
-- representation.
make_from_mag_ang_g = function(r, a)
  return make_from_mag_ang_polar(r, a)
end
-- same as before
-- Generic complex addition via the generic selectors/constructors.
function add_complex_g(z1, z2)
  local re = real_part_g(z1) + real_part_g(z2)
  local im = imag_part_g(z1) + imag_part_g(z2)
  return make_from_real_imag_g(re, im)
end
-- Generic complex subtraction via the generic selectors/constructors.
function sub_complex_g(z1, z2)
  local re = real_part_g(z1) - real_part_g(z2)
  local im = imag_part_g(z1) - imag_part_g(z2)
  return make_from_real_imag_g(re, im)
end
-- Generic complex multiplication: multiply magnitudes, add angles.
function mul_complex_g(z1, z2)
  local mag = magnitude_g(z1) * magnitude_g(z2)
  local ang = angle_g(z1) + angle_g(z2)
  return make_from_mag_ang_g(mag, ang)
end
-- Generic complex division: divide magnitudes, subtract angles.
function div_complex_g(z1, z2)
  local mag = magnitude_g(z1) / magnitude_g(z2)
  local ang = angle_g(z1) - angle_g(z2)
  return make_from_mag_ang_g(mag, ang)
end
-- Demo: (3+4i) + (3+4i) = 6+8i, printed as a rectangular-tagged table.
printx (add_complex_g(make_from_real_imag_g(3, 4), make_from_real_imag_g(3, 4)))
-- 2.4.3 Multiple Representations for Abstract Data - Data-Directed Programming and Additivity
-- To Be Done.
-- 2.5.1 Systems with Generic Operations - Generic Arithmetic Operations
-- To Be Done.
|