#!/usr/bin/env ruby
#
#  Created by Reginald Braithwaite on 2007-03-14.
#  Copyright (c) 2007. All rights reserved.

#            DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
#                    Version 2, December 2004
# 
# Copyright (C) 2004 Sam Hocevar
#  22 rue de Plaisance, 75014 Paris, France
# Everyone is permitted to copy and distribute verbatim or modified
# copies of this license document, and changing it is allowed as long
# as the name is changed.
# 
#            DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
#   TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
# 
#  0. You just DO WHAT THE FUCK YOU WANT TO.

# A Comprehension produces lists. It isn't just a syntactic form;
# it is an object that can be inspected, modified, or built
# dynamically.

require 'dsl'

class Comprehension

  class << self
    # Returns every tuple that picks one element from each of +lists+ such
    # that the picked elements' indices add up to exactly +n+. Tuples are
    # ordered by the index taken from the first list, then recursively by the
    # remaining lists. Returns [] when no such tuple exists (for example when
    # +n+ exceeds the sum of the lists' largest indices).
    #
    # Each list must respond to #size and integer #[]; at least one list
    # must be given.
    def combinations_by_sum n, *lists
      first, *rest = lists
      if rest.empty?
        # Base case: a single list contributes exactly its n-th element, if any.
        n < first.size ? [ [ first[n] ] ] : []
      else
        # The first list can absorb at most min(n, its last index) of the sum;
        # the remaining lists must account for what is left.
        last_index = n < first.size ? n : first.size - 1
        (0..last_index).flat_map do |index|
          combinations_by_sum(n - index, *rest).map { |tail| [ first[index] ] + tail }
        end
      end
    end

    # Concatenates combinations_by_sum(k, *lists) for k = n, n + 1, ... and
    # stops at the first k that yields no tuples. For non-empty lists every
    # index sum from 0 up to the maximum is reachable, so starting at 0
    # enumerates the whole product one anti-diagonal at a time.
    #
    # If a block is given, each tuple is splatted into it and the block's
    # results are returned instead of the tuples.
    #
    # Loops rather than recursing once per diagonal, so the call depth stays
    # constant however long the lists are, and the result is built in place
    # instead of being re-copied at every step.
    def cartesian_product_from n, *lists, &block
      tuples = []
      until (diagonal = combinations_by_sum(n, *lists)).empty?
        tuples.concat(diagonal)
        n += 1
      end
      block_given? ? tuples.map { |tuple| block.call(*tuple) } : tuples
    end

    # Returns the Cartesian product of +lists+ as an array of tuples, in
    # anti-diagonal order of index sums. For example, [0, 1, 2] x [:a, :b, :c]
    # yields [0, :a], [0, :b], [1, :a], [0, :c], [1, :b], [2, :a], ...
    # If a block is given, each tuple is splatted into it and the block's
    # results are returned instead.
    #
    # Arguments that respond to both #[] and #empty? are used as-is; anything
    # else is converted with #to_a first. Returns [] when called with no
    # arguments or when any argument is empty.
    def cartesian_product *lists, &block
      arrs = lists.map do |list|
        list.respond_to?(:[]) && list.respond_to?(:empty?) ? list : list.to_a
      end
      # By this class's convention the product of no lists is empty (not [[]]).
      return [] if arrs.empty?
      # A product with an empty factor has no tuples.
      return [] if arrs.any?(&:empty?)
      cartesian_product_from(0, *arrs, &block)
    end
  end

  # mapping: block producing each result element (nil: #given returns the
  #          binding hashes themselves).
  # terms:   Hash of term name => collection of values (anything with #map).
  # where:   optional filter block; nil keeps every binding.
  attr_accessor :mapping, :terms, :where

  # Creates a comprehension whose elements are produced by +block+.
  # Without a block, #given returns the raw binding hashes.
  def initialize &block
    @terms = {}
    @mapping = block
    @where = nil
  end

  # Adds +assignments+ (term name => values) to the comprehension's terms. If
  # +block+ is given, it replaces the filter. Then evaluates the comprehension.
  #
  # Terms accumulate across calls: a later call keeps earlier terms and
  # overrides only the names it supplies. Terms are combined in name order
  # (by #to_s) with Comprehension.cartesian_product, so with two or more
  # terms the bindings come out in anti-diagonal rather than lexicographic
  # order. Each binding is a Hash of term name => value. The filter and
  # mapping blocks are run through the dsl library's `let` with that Hash.
  #
  # Returns the mapped values, or the binding hashes when there is no
  # mapping block. Returns [] when there are no terms.
  def given assignments = {}, &block
    @where = block if block_given?
    assignments.each { |term, values| @terms[term] = values }
    names = terms.keys.sort_by { |name| name.to_s }
    # One list per term, each element a single-entry Hash { term => value },
    # so that a product tuple can be merged into one binding Hash.
    columns = names.map { |term| terms[term].map { |value| { term => value } } }
    bindings = self.class.cartesian_product(*columns).map do |tuple|
      tuple.inject({}) { |merged, pair| merged.merge(pair) }
    end
    bindings = bindings.select { |env| evaluate_in(env, where) } if where
    mapping ? bindings.map { |env| evaluate_in(env, mapping) } : bindings
  end

  # DSL exposing `list { ... }` as shorthand for Comprehension.new { ... }.
  class DSL < DomainSpecificLanguage
    def list &block
      Comprehension.new(&block)
    end
  end

  private

  # Runs +body+ through the dsl library's `with Let` / `let(env, &body)` form
  # and returns its result.
  # +body+ arrives as a local so the inner block captures it lexically.
  # NOTE(review): `with`'s evaluation semantics live in dsl; passing locals
  # avoids depending on which receiver the block is evaluated against.
  def evaluate_in env, body
    with Let do
      let(env, &body)
    end
  end

end

# ==============
# = Test Cases =
# ==============

require "test/unit"

class TestComprehension < Test::Unit::TestCase

  # End-to-end checks of the `list { ... }.given(...)` DSL: a single term,
  # two terms, and a `where` filter block passed to #given.
  def test_simple_cases
    with Comprehension::DSL do
      assert_equal(
          [1, 2, 3],
          list { x + 1 }.given(:x => 0..2) )
      assert_equal(
          [[0, :a], [0, :b], [1, :a], [1, :b]],
          list { [x, y] }.given(:x => [0, 1], :y => [:a, :b]) )
      assert_equal(
          [2, 4, 6, 8, 10],
          list { x }.given(:x => 1..10) { x % 2 == 0 } )
    end
  end

  # Helper called from inside a comprehension body in test_closure_cases,
  # exercising a call to a method of the test object from within `list`.
  def plus_seven num
    num + 7
  end

  # Returns a zero-argument lambda whose body refers to `x`; the comprehension
  # it is handed to (test_closure_cases) must supply `x` when running it.
  def times_maker num
    lambda { x * num }
  end

  # Checks that comprehension bodies can use an outer local variable, a
  # pre-built lambda passed with &, and a method of the enclosing test object.
  def test_closure_cases
    two = 2
    triple = times_maker(3)
    with Comprehension::DSL do
      assert_equal(
          [2, 4, 6],
          list { x * two }.given(:x => 1..3) )
      assert_equal(
          [3, 6, 9],
          list(&triple).given(:x => 1..3) )
      assert_equal(
          [8, 9, 10],
          list { plus_seven(x) }.given(:x => 1..3) )
    end
  end

  # The 2x2 case in test_simple_cases happens to coincide with lexicographic
  # order. A 3x3 product pins the anti-diagonal order that cartesian_product
  # actually uses. Also covers the block form, non-indexable input converted
  # with #to_a, and the empty-input conventions.
  def test_cartesian_product_order
    assert_equal(
        [[0, :a], [0, :b], [1, :a], [0, :c], [1, :b], [2, :a], [1, :c], [2, :b], [2, :c]],
        Comprehension.cartesian_product([0, 1, 2], [:a, :b, :c]) )
    assert_equal(
        [11, 21, 12, 22],
        Comprehension.cartesian_product([1, 2], [10, 20]) { |a, b| a + b } )
    assert_equal(
        [[1], [2], [3]],
        Comprehension.cartesian_product(1..3) )
    assert_equal([], Comprehension.cartesian_product)
    assert_equal([], Comprehension.cartesian_product([1, 2], []))
  end

end