diff --git a/lib/oga/xpath/compiler.rb b/lib/oga/xpath/compiler.rb index 6804245..1013a48 100644 --- a/lib/oga/xpath/compiler.rb +++ b/lib/oga/xpath/compiler.rb @@ -299,7 +299,7 @@ module Oga # # @see [#operator] # - def on_eq(ast, input) + def on_eq(ast, input, &block) conversion = literal('Conversion') operator(ast, input) do |left, right| @@ -308,7 +308,10 @@ module Oga conversion.to_compatible_types(left, right) ) - compatible_assign.followed_by(left.eq(right)) + operation = left.eq(right) + operation = operation.if_true(&block) if block # In a predicate + + compatible_assign.followed_by(operation) end end @@ -317,7 +320,7 @@ module Oga # # @see [#operator] # - def on_neq(ast, input) + def on_neq(ast, input, &block) conversion = literal('Conversion') operator(ast, input) do |left, right| @@ -326,19 +329,28 @@ module Oga conversion.to_compatible_types(left, right) ) - compatible_assign.followed_by(left != right) + operation = left != right + operation = operation.if_true(&block) if block # In a predicate + + compatible_assign.followed_by(operation) end end OPERATORS.each do |callback, (conv_method, ruby_method)| - define_method(callback) do |ast, input| + define_method(callback) do |ast, input, &block| conversion = literal('Conversion') operator(ast, input) do |left, right| - lval = conversion.__send__(conv_method, left) - rval = conversion.__send__(conv_method, right) + lval = conversion.__send__(conv_method, left) + rval = conversion.__send__(conv_method, right) + operation = lval.__send__(ruby_method, rval) - lval.__send__(ruby_method, rval) + # In a predicate + if block + operation = conversion.to_boolean(operation).if_true(&block) + end + + operation end end end @@ -348,10 +360,11 @@ module Oga # # @see [#operator] # - def on_pipe(ast, input) + def on_pipe(ast, input, &block) left, right = *ast - union = unique_literal('union') + union = unique_literal('union') + conversion = literal('Conversion') left_push = process(left, input) do |node| union << node @@ -361,10 +374,18 @@ module Oga union << node end - union.assign(literal(XML::NodeSet).new) + push_ast = union.assign(literal(XML::NodeSet).new) .followed_by(left_push) .followed_by(right_push) - .followed_by(union) + + # In a predicate + if block + final = conversion.to_boolean(union).if_true(&block) + else + final = union + end + + push_ast.followed_by(final) end # @param [AST::Node] ast