1// Licensed to the Software Freedom Conservancy (SFC) under one2// or more contributor license agreements. See the NOTICE file3// distributed with this work for additional information4// regarding copyright ownership. The SFC licenses this file5// to you under the Apache License, Version 2.0 (the6// "License"); you may not use this file except in compliance7// with the License. You may obtain a copy of the License at8//9// http://www.apache.org/licenses/LICENSE-2.010//11// Unless required by applicable law or agreed to in writing,12// software distributed under the License is distributed on an13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY14// KIND, either express or implied. See the License for the15// specific language governing permissions and limitations16// under the License.17package org.openqa.selenium.remote;18import net.bytebuddy.ByteBuddy;19import net.bytebuddy.description.annotation.AnnotationDescription;20import net.bytebuddy.dynamic.DynamicType;21import net.bytebuddy.dynamic.loading.ClassLoadingStrategy;22import net.bytebuddy.implementation.FixedValue;23import net.bytebuddy.implementation.MethodDelegation;24import org.openqa.selenium.Beta;25import org.openqa.selenium.Capabilities;26import org.openqa.selenium.HasCapabilities;27import org.openqa.selenium.ImmutableCapabilities;28import org.openqa.selenium.WebDriver;29import org.openqa.selenium.WebDriverException;30import org.openqa.selenium.WrapsDriver;31import org.openqa.selenium.internal.Require;32import org.openqa.selenium.remote.html5.AddApplicationCache;33import org.openqa.selenium.remote.html5.AddLocationContext;34import org.openqa.selenium.remote.html5.AddWebStorage;35import org.openqa.selenium.remote.mobile.AddNetworkConnection;36import java.lang.reflect.Field;37import java.lang.reflect.Modifier;38import java.util.HashSet;39import java.util.List;40import java.util.ServiceLoader;41import java.util.Set;42import java.util.function.BiFunction;43import java.util.function.Predicate;44import java.util.stream.Collectors;45import java.util.stream.Stream;46import java.util.stream.StreamSupport;47import static java.util.Collections.unmodifiableSet;48import static net.bytebuddy.matcher.ElementMatchers.anyOf;49import static net.bytebuddy.matcher.ElementMatchers.named;50/**51 * Enhance the interfaces implemented by an instance of the52 * {@link org.openqa.selenium.WebDriver} based on the returned53 * {@link org.openqa.selenium.Capabilities} of the driver.54 *55 * Note: this class is still experimental. Use at your own risk.56 */57@Beta58public class Augmenter {59 private final Set<Augmentation<?>> augmentations;60 public Augmenter() {61 Set<Augmentation<?>> augmentations = new HashSet<>();62 Stream.of(63 new AddApplicationCache(),64 new AddLocationContext(),65 new AddNetworkConnection(),66 new AddRotatable(),67 new AddWebStorage()68 ).forEach(provider -> augmentations.add(createAugmentation(provider)));69 StreamSupport.stream(ServiceLoader.load(AugmenterProvider.class).spliterator(), false)70 .forEach(provider -> augmentations.add(createAugmentation(provider)));71 this.augmentations = unmodifiableSet(augmentations);72 }73 private static <X> Augmentation<X> createAugmentation(AugmenterProvider<X> provider) {74 Require.nonNull("Interface provider", provider);75 return new Augmentation<>(provider.isApplicable(),76 provider.getDescribedInterface(),77 provider::getImplementation);78 }79 private Augmenter(Set<Augmentation<?>> augmentations, Augmentation<?> toAdd) {80 Require.nonNull("Current list of augmentations", augmentations);81 Require.nonNull("Augmentation to add", toAdd);82 Set<Augmentation<?>> toUse = new HashSet<>(augmentations.size() + 1);83 toUse.addAll(augmentations);84 toUse.add(toAdd);85 this.augmentations = unmodifiableSet(toUse);86 }87 public <X> Augmenter addDriverAugmentation(AugmenterProvider<X> provider) {88 Require.nonNull("Interface provider", provider);89 return addDriverAugmentation(90 provider.isApplicable(),91 provider.getDescribedInterface(),92 provider::getImplementation);93 }94 public <X> Augmenter addDriverAugmentation(95 String capabilityName,96 Class<X> implementThis,97 BiFunction<Capabilities, ExecuteMethod, X> usingThis) {98 Require.nonNull("Capability name", capabilityName);99 Require.nonNull("Interface to implement", implementThis);100 Require.nonNull("Concrete implementation", usingThis, "of %s", implementThis);101 return addDriverAugmentation(check(capabilityName), implementThis, usingThis);102 }103 public <X> Augmenter addDriverAugmentation(104 Predicate<Capabilities> whenThisMatches,105 Class<X> implementThis,106 BiFunction<Capabilities, ExecuteMethod, X> usingThis) {107 Require.nonNull("Capability predicate", whenThisMatches);108 Require.nonNull("Interface to implement", implementThis);109 Require.nonNull("Concrete implementation", usingThis, "of %s", implementThis);110 Require.precondition(implementThis.isInterface(), "Expected %s to be an interface", implementThis);111 return new Augmenter(augmentations, new Augmentation<>(whenThisMatches, implementThis, usingThis));112 }113 private Predicate<Capabilities> check(String capabilityName) {114 return caps -> {115 Require.nonNull("Capability name", capabilityName);116 Object value = caps.getCapability(capabilityName);117 if (value instanceof Boolean && !((Boolean) value)) {118 return false;119 }120 return value != null;121 };122 }123 /**124 * Enhance the interfaces implemented by this instance of WebDriver iff that instance is a125 * {@link org.openqa.selenium.remote.RemoteWebDriver}.126 *127 * The WebDriver that is returned may well be a dynamic proxy. You cannot rely on the concrete128 * implementing class to remain constant.129 *130 * @param driver The driver to enhance131 * @return A class implementing the described interfaces.132 */133 public WebDriver augment(WebDriver driver) {134 Require.nonNull("WebDriver", driver);135 Require.precondition(driver instanceof HasCapabilities, "Driver must have capabilities", driver);136 Capabilities caps = ImmutableCapabilities.copyOf(((HasCapabilities) driver).getCapabilities());137 // Collect the interfaces to apply138 List<Augmentation<?>> matchingAugmenters = augmentations.stream()139 // Only consider the augmenters that match interfaces we don't already implement140 .filter(augmentation -> !augmentation.interfaceClass.isAssignableFrom(driver.getClass()))141 // And which match an augmentation we have142 .filter(augmentation -> augmentation.whenMatches.test(caps))143 .collect(Collectors.toList());144 if (matchingAugmenters.isEmpty()) {145 return driver;146 }147 // Grab a remote execution method, if possible148 RemoteWebDriver remote = extractRemoteWebDriver(driver);149 ExecuteMethod execute = remote == null ?150 (commandName, parameters) -> { throw new WebDriverException("Cannot execute remote command: " + commandName); } :151 new RemoteExecuteMethod(remote);152 DynamicType.Builder<? extends WebDriver> builder = new ByteBuddy()153 .subclass(driver.getClass())154 .annotateType(AnnotationDescription.Builder.ofType(Augmentable.class).build())155 .method(named("isAugmented")).intercept(FixedValue.value(true));156 for (Augmentation<?> augmentation : augmentations) {157 Class<?> iface = augmentation.interfaceClass;158 Object instance = augmentation.implementation.apply(caps, execute);159 builder = builder.implement(iface)160 .method(anyOf(iface.getDeclaredMethods()))161 .intercept(MethodDelegation.to(instance, iface));162 }163 Class<? extends WebDriver> definition = builder.make()164 .load(driver.getClass().getClassLoader(), ClassLoadingStrategy.Default.WRAPPER)165 .getLoaded()166 .asSubclass(driver.getClass());167 try {168 WebDriver toReturn = definition.getDeclaredConstructor().newInstance();169 copyFields(driver.getClass(), driver, toReturn);170 return toReturn;171 } catch (ReflectiveOperationException e) {172 throw new IllegalStateException("Unable to create new proxy", e);173 }174 }175 private RemoteWebDriver extractRemoteWebDriver(WebDriver driver) {176 Require.nonNull("WebDriver", driver);177 if (driver instanceof RemoteWebDriver) {178 return (RemoteWebDriver) driver;179 }180 if (driver instanceof WrapsDriver) {181 return extractRemoteWebDriver(((WrapsDriver) driver).getWrappedDriver());182 }183 return null;184 }185 private void copyFields(Class<?> clazz, Object source, Object target) {186 if (Object.class.equals(clazz)) {187 // Stop!188 return;189 }190 for (Field field : clazz.getDeclaredFields()) {191 copyField(source, target, field);192 }193 copyFields(clazz.getSuperclass(), source, target);194 }195 private void copyField(Object source, Object target, Field field) {196 if (Modifier.isFinal(field.getModifiers())) {197 return;198 }199 try {200 field.setAccessible(true);201 Object value = field.get(source);202 field.set(target, value);203 } catch (IllegalAccessException e) {204 throw new RuntimeException(e);205 }206 }207 private static class Augmentation<X> {208 public final Predicate<Capabilities> whenMatches;209 public final Class<X> interfaceClass;210 public final BiFunction<Capabilities, ExecuteMethod, X> implementation;211 public Augmentation(212 Predicate<Capabilities> whenMatches,213 Class<X> interfaceClass,214 BiFunction<Capabilities, ExecuteMethod, X> implementation) {215 this.whenMatches = Require.nonNull("Capabilities predicate", whenMatches);216 this.interfaceClass = Require.nonNull("Interface to implement", interfaceClass);217 this.implementation = Require.nonNull("Interface implementation", implementation);218 Require.precondition(219 interfaceClass.isInterface(),220 "%s must be an interface, not a concrete class",221 interfaceClass);222 }223 }224}...